使用 @tffunction 的 Tensorflow2 警告
Posted
技术标签:
【中文标题】使用 @tffunction 的 Tensorflow2 警告【英文标题】:Tensorflow2 warning using @tffunction 【发布时间】:2020-03-17 05:25:55 【问题描述】:来自 Tensorflow 2 的示例代码
writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")
@tf.function
def my_func(step):
with writer.as_default():
# other model code would go here
tf.summary.scalar("my_metric", 0.5, step=step)
for step in range(100):
my_func(step)
writer.flush()
但它会发出警告。
警告:tensorflow:5 次调用触发的 tf.function 回溯中的 5 次。追踪是昂贵的 并且追踪次数过多很可能是由于通过了python 对象而不是张量。此外, tf.function 有 experimental_relax_shapes=True 放松参数形状的选项 可以避免不必要的回溯。请参阅 https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args 和https://www.tensorflow.org/api_docs/python/tf/function 了解更多 详情。
有没有更好的方法来做到这一点?
【问题讨论】:
【参考方案1】:tf.function
有一些“特殊性”。我强烈推荐阅读这篇文章:https://www.tensorflow.org/tutorials/customization/performance
在这种情况下,问题在于每次您使用不同的输入签名调用时,都会“回溯”该函数(即构建一个新图)。对于张量,输入签名指的是 shape 和 dtype,但对于 Python 数字,每个新值都被解释为“不同”。在这种情况下,因为您使用每次都更改的step
变量调用函数,所以每次都会回溯该函数。对于“真实”代码(例如在函数内部调用模型),这将非常慢。
您可以通过简单地将step
转换为张量来修复它,在这种情况下,不同的值将不算作新的输入签名:
for step in range(100):
step = tf.convert_to_tensor(step, dtype=tf.int64)
my_func(step)
writer.flush()
或者使用tf.range
直接获取张量:
for step in tf.range(100):
step = tf.cast(step, tf.int64)
my_func(step)
writer.flush()
这应该不会产生警告(而且会更快)。
【讨论】:
如果不使用@tf.function
,这个问题是否也会导致代码执行速度变慢?
不,因为这样就没有要构建/回溯的图表了。
@tf.function
只接受张量而不是缩放器类型。转换为张量很好。以上是关于使用 @tffunction 的 Tensorflow2 警告的主要内容,如果未能解决你的问题,请参考以下文章