在TF 2.0中使用tf.keras,如何定义依赖于学习阶段的自定义层?
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了在TF 2.0中使用tf.keras,如何定义依赖于学习阶段的自定义层?相关的知识,希望对你有一定的参考价值。
我想使用tf.keras构建自定义图层。为简单起见,假设它应该在训练期间返回输入* 2并在测试期间输入* 3。这样做的正确方法是什么?
我试过这种方法:
class CustomLayer(Layer):
@tf.function
def call(self, inputs, training=None):
if training:
return inputs*2
else:
return inputs*3
然后我可以像这样使用这个类:
>>> layer = CustomLayer()
>>> layer(10)
tf.Tensor(30, shape=(), dtype=int32)
>>> layer(10, training=True)
tf.Tensor(20, shape=(), dtype=int32)
它工作正常!但是,当我在模型中使用这个类,并且我调用它的fit()
方法时,似乎training
没有设置为True
。我尝试在call()
方法的开头添加以下代码,但training
始终等于0。
if training is None:
training = K.learning_phase()
我错过了什么?
编辑
我找到了一个解决方案(请参阅我的回答),但我仍在寻找使用@tf.function
的更好的解决方案(我更喜欢亲笔签名到这个smart_cond()
业务)。不幸的是,看起来K.learning_phase()
与@tf.function
不相称(我的猜测是当call()
函数被跟踪时,学习阶段被硬编码到图中:因为这发生在调用fit()
方法之前,学习阶段是总是0)。这可能是一个错误,或者在使用@tf.function
时可能还有另一种方法可以进入学习阶段。
FrançoisChollet确认使用@tf.function
时的正确解决方案是:
class CustomLayer(Layer):
@tf.function
def call(self, inputs, training=None):
if training is None:
training = K.learning_phase()
if training:
return inputs * 2
else:
return inputs * 3
目前有一个错误(截至2019年2月15日),使training
总是等于0
,但这很快就会修复。
以下代码不使用@tf.function
,因此它看起来不太好(因为它不使用签名),但它工作正常:
from tensorflow.python.keras.utils.tf_utils import smart_cond
class CustomLayer(Layer):
def call(self, inputs, training=None):
if training is None:
training = K.learning_phase()
return smart_cond(training, lambda: inputs * 2, lambda: inputs * 3)
以上是关于在TF 2.0中使用tf.keras,如何定义依赖于学习阶段的自定义层?的主要内容,如果未能解决你的问题,请参考以下文章
如何在 Tensorflow 2.0 + Keras 中进行并行 GPU 推理?
访问在 TF 2.0 中未显式公开为层的 Keras 模型的中间张量
如何在具有使用@tf.keras.utils.register_keras_serializable 注册的自定义函数的 Tensorflow Serving 中提供模型?
如何在 tf.keras 自定义损失函数中触发 python 函数?