使用 tf.function 的 Tensorflow 2.0 模型非常慢,并且每次列车数量发生变化时都会重新编译。 Eager 的运行速度提高了大约 4 倍

Posted

技术标签:

【中文标题】使用 tf.function 的 Tensorflow 2.0 模型非常慢,并且每次列车数量发生变化时都会重新编译。 Eager 的运行速度提高了大约 4 倍【英文标题】:Tensorflow 2.0 model using tf.function very slow and is recompiling every time the train count changes. Eager runs about 4x faster 【发布时间】:2019-09-06 17:31:34 【问题描述】:

我有由未编译的 keras 代码构建的模型,并试图通过自定义训练循环运行它们。

TF 2.0 eager(默认)代码在 CPU(笔记本电脑)上运行大约 30 秒。当我使用包装的 tf.function 调用方法创建 keras 模型时,它的运行速度要慢得多,而且似乎需要很长时间才能启动,尤其是“第一次”时间。

例如,在 tf.function 代码中,10 个样本的初始训练需要 40 秒,10 个样本的后续训练需要 2 秒。

在 20 个样本上,初始需要 50 秒,后续需要 4 秒。

第一次训练 1 个样本需要 2 秒,后续需要 200 毫秒。

所以看起来每次调用 train 都在创建一个新图,其中复杂性随列车数量而变化!?

我只是在做这样的事情:

@tf.function
def train(n=10):
    step = 0
    loss = 0.0
    accuracy = 0.0
    for i in range(n):
        step += 1
        d, dd, l = train_one_step(model, opt, data)
        tf.print(dd)
        with tf.name_scope('train'):
            for k in dd:
                tf.summary.scalar(k, dd[k], step=step)
        if tf.equal(step % 10, 0):
            tf.print(dd)
    d.update(dd)
    return d

模型是keras.model.Model,根据示例使用@tf.function 装饰call 方法。

【问题讨论】:

【参考方案1】:

我在这里分析了@tf.function Using a Python native type 的这种行为。

简而言之:tf.function 的设计不会自动将 Python 原生类型装箱到具有明确定义的 dtypetf.Tensor 对象。

如果您的函数接受tf.Tensor 对象,则在第一次调用该函数时会分析该函数,并构建图形并与该函数关联。在每个非第一次调用中,如果 tf.Tensor 对象的 dtype 匹配,则重复使用该图。

但在使用 Python 原生类型的情况下,每次使用不同的值调用函数时都会构建 graphg

简而言之:如果您打算使用 @tf.function,请将您的代码设计为在任何地方都使用 tf.Tensor 而不是 Python 变量。

tf.function 不是一个能够神奇地加速在 Eager 模式下运行良好的函数的包装器;是一个包装器,需要设计渴望函数(主体、输入参数、dytpes)了解创建图形后会发生什么,以获得真正的加速。

【讨论】:

这很棒......我猜应该是文档中的一个大警告。如果有的话,我肯定错过了。 我不明白的一件事是为什么他们在这里有 tf.function 示例tensorflow.org/alpha/guide/effective_tf2 ...如果这是签名的已知问题,则在 args 中包含模型等。跨度> 传递模型(即 keras 对象)、tf.data.dataset 或任何 tf.* 对象根本不是问题。仅当您传递 Python 本机类型时,性能才会降低 我很乐意提供帮助! 谢谢,@nessuno。我知道您指的是原生 numeric 类型,但我还想补充一点,即使列表在 python 中也被视为原生类型,张量列表也可以正常工作。

以上是关于使用 tf.function 的 Tensorflow 2.0 模型非常慢,并且每次列车数量发生变化时都会重新编译。 Eager 的运行速度提高了大约 4 倍的主要内容,如果未能解决你的问题,请参考以下文章

为啥我不能在 @tf.function 中使用 TensorArray.gather()?

使用@tf.function 进行自定义张量流训练的内存泄漏

我应该将@tf.function 用于所有功能吗?

关于TensorFlow2的tf.function()和AutoGraph的一些问题解决

Tensorflow 2.0:自定义 keras 指标导致 tf.function 回溯警告

如何从tensorflow 2.0中的tf.function获取图形?