在TF 2.0中使用tf.keras,如何定义依赖于学习阶段的自定义层?

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了在TF 2.0中使用tf.keras,如何定义依赖于学习阶段的自定义层?相关的知识,希望对你有一定的参考价值。

我想使用tf.keras构建自定义图层。为简单起见,假设它应该在训练期间返回输入* 2并在测试期间输入* 3。这样做的正确方法是什么?

我试过这种方法:

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training:
            return inputs*2
        else:
            return inputs*3

然后我可以像这样使用这个类:

>>> layer = CustomLayer()
>>> layer(10)
tf.Tensor(30, shape=(), dtype=int32)
>>> layer(10, training=True)
tf.Tensor(20, shape=(), dtype=int32)

它工作正常!但是,当我在模型中使用这个类,并且我调用它的fit()方法时,似乎training没有设置为True。我尝试在call()方法的开头添加以下代码,但training始终等于0。

if training is None:
    training = K.learning_phase()

我错过了什么?

编辑

我找到了一个解决方案(请参阅我的回答),但我仍在寻找使用@tf.function的更好的解决方案(我更喜欢亲笔签名到这个smart_cond()业务)。不幸的是,看起来K.learning_phase()@tf.function不相称(我的猜测是当call()函数被跟踪时,学习阶段被硬编码到图中:因为这发生在调用fit()方法之前,学习阶段是总是0)。这可能是一个错误,或者在使用@tf.function时可能还有另一种方法可以进入学习阶段。

答案

FrançoisChollet确认使用@tf.function时的正确解决方案是:

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        if training:
            return inputs * 2
        else:
            return inputs * 3

目前有一个错误(截至2019年2月15日),使training总是等于0,但这很快就会修复。

另一答案

以下代码不使用@tf.function,因此它看起来不太好(因为它不使用签名),但它工作正常:

from tensorflow.python.keras.utils.tf_utils import smart_cond

class CustomLayer(Layer):
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        return smart_cond(training, lambda: inputs * 2, lambda: inputs * 3)

以上是关于在TF 2.0中使用tf.keras,如何定义依赖于学习阶段的自定义层?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Tensorflow 2.0 + Keras 中进行并行 GPU 推理?

访问在 TF 2.0 中未显式公开为层的 Keras 模型的中间张量

如何在具有使用@tf.keras.utils.register_keras_serializable 注册的自定义函数的 Tensorflow Serving 中提供模型?

如何在 tf.keras 自定义损失函数中触发 python 函数?

tf.kerastf.keras使用tensorflow中定义的optimizer

使用 tf.keras.Model.fit 进行训练时如何将自定义摘要添加到 tensorboard