Tensorflow 2.0:自定义 keras 指标导致 tf.function 回溯警告
Posted
技术标签:
【中文标题】Tensorflow 2.0:自定义 keras 指标导致 tf.function 回溯警告【英文标题】:Tensorflow 2.0: custom keras metric caused tf.function retracing warning 【发布时间】:2020-03-07 21:27:51 【问题描述】:当我使用以下自定义指标(keras 样式)时:
from sklearn.metrics import classification_report, f1_score
from tensorflow.keras.callbacks import Callback
class Metrics(Callback):
def __init__(self, dev_data, classifier, dataloader):
self.best_f1_score = 0.0
self.dev_data = dev_data
self.classifier = classifier
self.predictor = Predictor(classifier, dataloader)
self.dataloader = dataloader
def on_epoch_end(self, epoch, logs=None):
print("start to evaluate....")
_, preds = self.predictor(self.dev_data)
y_trues, y_preds = [self.dataloader.label_vector(v["label"]) for v in self.dev_data], preds
f1 = f1_score(y_trues, y_preds, average="weighted")
print(classification_report(y_trues, y_preds,
target_names=self.dataloader.vocab.labels))
if f1 > self.best_f1_score:
self.best_f1_score = f1
self.classifier.save_model()
print("best metrics, save model...")
我收到以下警告:
W1106 10:49:14.171694 4745115072 def_function.py:474] 在 0x14a3f9d90> 对 .distributed_function 的最后 11 次调用中有 6 次触发了 tf.function 回溯。跟踪很昂贵,并且过多的跟踪可能是由于传递了 python 对象而不是张量。此外, tf.function 具有 Experimental_relax_shapes=True 选项,可以放宽可以避免不必要的回溯的参数形状。详情请参考https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args和https://www.tensorflow.org/api_docs/python/tf/function。
【问题讨论】:
【参考方案1】:当回溯 TF 函数时会出现此警告,因为其参数的形状或 dtype(对于张量)甚至值(Python 或 np 对象或变量)发生变化。
在一般情况下,解决方法是在定义您传递给 Keras 或 TF 某处的自定义函数之前使用 @tf.function(experimental_relax_shapes=True)。 这会尝试检测并避免不必要的回溯,但不能保证解决问题。
在你的情况下,我猜 Predictor 类是一个自定义类,所以将 @tf.function(experimental_relax_shapes=True) 放在 Predictor.predict() 的定义之前。
【讨论】:
【参考方案2】:导入tensorflow后添加这一行:
tf.compat.v1.disable_eager_execution()
【讨论】:
【参考方案3】:然后使用@tf.function(experimental_relax_shapes=True) 可能会解决您的问题
【讨论】:
装饰器应该加在哪里?以上是关于Tensorflow 2.0:自定义 keras 指标导致 tf.function 回溯警告的主要内容,如果未能解决你的问题,请参考以下文章
如何在 tensorflow 2.0 w/keras 中保存/恢复大型模型?
安装 tensorflow 1.3 后是不是需要单独安装 keras 2.0?
Keras 2.3.0 发布:支持TensorFlow 2.0!!!!!
如何为 keras 模型使用 tensorflow 自定义损失?