Tensorflow 2.0:从回调中访问批次的张量

Posted

技术标签:

【中文标题】Tensorflow 2.0:从回调中访问批次的张量【英文标题】:Tensorflow 2.0: Accessing a batch's tensors from a callback 【发布时间】:2019-11-21 03:09:08 【问题描述】:

我正在使用 Tensorflow 2.0 并尝试编写一个 tf.keras.callbacks.Callback 来读取我的批处理的 model 的输入和输出。

我希望能够覆盖 on_batch_end 并访问 model.inputsmodel.outputs,但它们不是具有我可以访问的值的 EagerTensor。无论如何可以访问批次中涉及的实际张量值吗?

这有很多实际用途,例如将这些张量输出到 Tensorboard 进行调试,或者将它们序列化以用于其他目的。我知道我可以使用model.predict 再次运行整个模型,但这将迫使我通过网络运行每个输入两次(而且我可能还有非确定性数据生成器)。关于如何实现这一点的任何想法?

【问题讨论】:

【参考方案1】:

不,无法在回调中访问输入和输出的实际值。这不仅仅是回调设计目标的一部分。回调只能访问模型、要拟合的参数、纪元数和一些指标值。如您所见,model.input 和 model.output 仅指向符号 KerasTensors,而不是实际值。

要做你想做的事,你可以获取输入,将它与你关心的输出堆叠(可能与 RaggedTensor 一起),然后将其作为模型的额外输出。然后将您的功能实现为仅读取 y_pred 的自定义指标。在你的metric里面,unstack y_pred得到输入和输出,然后可视化/序列化/等等。Metrics

另一种方法可能是实现一个自定义层,该层使用 py_function 在 python 中调用函数。这在严肃的训练期间会非常慢,但在诊断/调试期间可能就足够了。

【讨论】:

以上是关于Tensorflow 2.0:从回调中访问批次的张量的主要内容,如果未能解决你的问题,请参考以下文章

警告:tensorflow:`write_grads` 将在 TensorFlow 2.0 中忽略`TensorBoard` 回调

如何在 TensorFlow 中处理具有可变长度序列的批次?

在 Tensorflow 中将数据拆分为批次进行分类

tensorflow 批次读取文件内的数据,并将顺序随机化处理. --[python]

如何在特定数量的列中找到 tensorflow 数据集批次中的最大值?

从 tensorflow 2.0 中保存的模型运行预测