使用 @tffunction 的 Tensorflow2 警告

Posted

技术标签:

【中文标题】使用 @tffunction 的 Tensorflow2 警告【英文标题】:Tensorflow2 warning using @tffunction 【发布时间】:2020-03-17 05:25:55 【问题描述】:

来自 Tensorflow 2 的示例代码

writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")

@tf.function
def my_func(step):
  with writer.as_default():
    # other model code would go here
    tf.summary.scalar("my_metric", 0.5, step=step)

for step in range(100):
  my_func(step)
  writer.flush()

但它会发出警告。

警告:tensorflow:5 次调用触发的 tf.function 回溯中的 5 次。追踪是昂贵的 并且追踪次数过多很可能是由于通过了python 对象而不是张量。此外, tf.function 有 experimental_relax_shapes=True 放松参数形状的选项 可以避免不必要的回溯。请参阅 https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args 和https://www.tensorflow.org/api_docs/python/tf/function 了解更多 详情。

有没有更好的方法来做到这一点?

【问题讨论】:

【参考方案1】:

tf.function 有一些“特殊性”。我强烈推荐阅读这篇文章:https://www.tensorflow.org/tutorials/customization/performance

在这种情况下,问题在于每次您使用不同的输入签名调用时,都会“回溯”该函数(即构建一个新图)。对于张量,输入签名指的是 shape 和 dtype,但对于 Python 数字,每个新值都被解释为“不同”。在这种情况下,因为您使用每次都更改的step 变量调用函数,所以每次都会回溯该函数。对于“真实”代码(例如在函数内部调用模型),这将非常慢。

您可以通过简单地将step 转换为张量来修复它,在这种情况下,不同的值将算作新的输入签名:

for step in range(100):
    step = tf.convert_to_tensor(step, dtype=tf.int64)
    my_func(step)
    writer.flush()

或者使用tf.range直接获取张量:

for step in tf.range(100):
    step = tf.cast(step, tf.int64)
    my_func(step)
    writer.flush()

这应该不会产生警告(而且会更快)。

【讨论】:

如果不使用@tf.function,这个问题是否也会导致代码执行速度变慢? 不,因为这样就没有要构建/回溯的图表了。 @tf.function 只接受张量而不是缩放器类型。转换为张量很好。

以上是关于使用 @tffunction 的 Tensorflow2 警告的主要内容,如果未能解决你的问题,请参考以下文章

使用 tensorflow 我遇到了这样的错误

02.Tensorflow基础用法

ubuntu 17.04 下搭建深度学习环境

构建 TensorFlowLite Swift 自定义框架

Tensorflow 对象检测 API 中的过拟合

(原)torch7中指定可见的GPU