Numba 可以与 Tensorflow 一起使用吗?

Posted

技术标签:

【中文标题】Numba 可以与 Tensorflow 一起使用吗?【英文标题】:Can Numba be used with Tensorflow? 【发布时间】:2017-04-29 04:40:36 【问题描述】:

Numba 可以用来编译与 Tensorflow 接口的 Python 代码吗? IE。 Tensorflow 宇宙之外的计算将使用 Numba 运行以提高速度。我还没有找到有关如何执行此操作的任何资源。

【问题讨论】:

看到这个问题github.com/tensorflow/tensorflow/issues/47 您可以使用 session.run 从 TensorFlow 获取 numpy 数组,在它们上运行 numba,然后使用 feed_dict 将结果反馈到 TF 中 【参考方案1】:

我知道这并不能直接回答您的问题,但它可能是一个不错的选择。 Numba 正在使用即时 (JIT) 编译。因此,您可以按照 TensorFlow 官方文档 here 中的说明,了解如何在 TensorFlow 中使用 JIT(但不在 Numba 生态系统中)。

【讨论】:

【参考方案2】:

您可以使用tf.numpy_function 或tf.py_func 包装python 函数并将其用作TensorFlow 操作。这是我使用的一个示例:

@jit
def dice_coeff_nb(y_true, y_pred):
    "Calculates dice coefficient"
    smooth = np.float32(1)
    y_true_f = np.reshape(y_true, [-1])
    y_pred_f = np.reshape(y_pred, [-1])
    intersection = np.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (np.sum(y_true_f) +
                                            np.sum(y_pred_f) + smooth)
    return score

@jit
def dice_loss_nb(y_true, y_pred):
    "Calculates dice loss"
    loss = 1 - dice_coeff_nb(y_true, y_pred)
    return loss

def bce_dice_loss_nb(y_true, y_pred):
    "Adds dice_loss to categorical_crossentropy"
    loss =  tf.numpy_function(dice_loss_nb, [y_true, y_pred], tf.float64) + \
            tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return loss

然后我在训练一个 tf.keras 模型时使用了这个损失函数:

...
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss=bce_dice_loss_nb)

【讨论】:

以上是关于Numba 可以与 Tensorflow 一起使用吗?的主要内容,如果未能解决你的问题,请参考以下文章

将 numba.jit 与 scipy.integrate.ode 一起使用

Python numpy:无法将 datetime64[ns] 转换为 datetime64[D](与 Numba 一起使用)

numba初体验

TensorFlow 可以与 Theano 一起安装吗?

cuDNN v6.0 目前可以与 TensorFlow 一起使用吗?

如何嵌套 numba jitclass