Numba 可以与 Tensorflow 一起使用吗?
Posted
技术标签:
【中文标题】Numba 可以与 Tensorflow 一起使用吗?【英文标题】:Can Numba be used with Tensorflow? 【发布时间】:2017-04-29 04:40:36 【问题描述】:Numba 可以用来编译与 Tensorflow 接口的 Python 代码吗? IE。 Tensorflow 宇宙之外的计算将使用 Numba 运行以提高速度。我还没有找到有关如何执行此操作的任何资源。
【问题讨论】:
看到这个问题github.com/tensorflow/tensorflow/issues/47 您可以使用session.run
从 TensorFlow 获取 numpy 数组,在它们上运行 numba,然后使用 feed_dict
将结果反馈到 TF 中
【参考方案1】:
我知道这并不能直接回答您的问题,但它可能是一个不错的选择。 Numba 正在使用即时 (JIT) 编译。因此,您可以按照 TensorFlow 官方文档 here 中的说明,了解如何在 TensorFlow 中使用 JIT(但不在 Numba 生态系统中)。
【讨论】:
【参考方案2】:您可以使用tf.numpy_function 或tf.py_func 包装python 函数并将其用作TensorFlow 操作。这是我使用的一个示例:
@jit
def dice_coeff_nb(y_true, y_pred):
"Calculates dice coefficient"
smooth = np.float32(1)
y_true_f = np.reshape(y_true, [-1])
y_pred_f = np.reshape(y_pred, [-1])
intersection = np.sum(y_true_f * y_pred_f)
score = (2. * intersection + smooth) / (np.sum(y_true_f) +
np.sum(y_pred_f) + smooth)
return score
@jit
def dice_loss_nb(y_true, y_pred):
"Calculates dice loss"
loss = 1 - dice_coeff_nb(y_true, y_pred)
return loss
def bce_dice_loss_nb(y_true, y_pred):
"Adds dice_loss to categorical_crossentropy"
loss = tf.numpy_function(dice_loss_nb, [y_true, y_pred], tf.float64) + \
tf.keras.losses.categorical_crossentropy(y_true, y_pred)
return loss
然后我在训练一个 tf.keras 模型时使用了这个损失函数:
...
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss=bce_dice_loss_nb)
【讨论】:
以上是关于Numba 可以与 Tensorflow 一起使用吗?的主要内容,如果未能解决你的问题,请参考以下文章
将 numba.jit 与 scipy.integrate.ode 一起使用
Python numpy:无法将 datetime64[ns] 转换为 datetime64[D](与 Numba 一起使用)