Tensorflow 2.0:从回调中访问批次的张量
Posted
技术标签:
【中文标题】Tensorflow 2.0:从回调中访问批次的张量【英文标题】:Tensorflow 2.0: Accessing a batch's tensors from a callback 【发布时间】:2019-11-21 03:09:08 【问题描述】:我正在使用 Tensorflow 2.0 并尝试编写一个 tf.keras.callbacks.Callback
来读取我的批处理的 model
的输入和输出。
我希望能够覆盖 on_batch_end
并访问 model.inputs
和 model.outputs
,但它们不是具有我可以访问的值的 EagerTensor
。无论如何可以访问批次中涉及的实际张量值吗?
这有很多实际用途,例如将这些张量输出到 Tensorboard 进行调试,或者将它们序列化以用于其他目的。我知道我可以使用model.predict
再次运行整个模型,但这将迫使我通过网络运行每个输入两次(而且我可能还有非确定性数据生成器)。关于如何实现这一点的任何想法?
【问题讨论】:
【参考方案1】:不,无法在回调中访问输入和输出的实际值。这不仅仅是回调设计目标的一部分。回调只能访问模型、要拟合的参数、纪元数和一些指标值。如您所见,model.input 和 model.output 仅指向符号 KerasTensors,而不是实际值。
要做你想做的事,你可以获取输入,将它与你关心的输出堆叠(可能与 RaggedTensor 一起),然后将其作为模型的额外输出。然后将您的功能实现为仅读取 y_pred 的自定义指标。在你的metric里面,unstack y_pred得到输入和输出,然后可视化/序列化/等等。Metrics
另一种方法可能是实现一个自定义层,该层使用 py_function 在 python 中调用函数。这在严肃的训练期间会非常慢,但在诊断/调试期间可能就足够了。
【讨论】:
以上是关于Tensorflow 2.0:从回调中访问批次的张量的主要内容,如果未能解决你的问题,请参考以下文章
警告:tensorflow:`write_grads` 将在 TensorFlow 2.0 中忽略`TensorBoard` 回调
如何在 TensorFlow 中处理具有可变长度序列的批次?
tensorflow 批次读取文件内的数据,并将顺序随机化处理. --[python]