训练和验证模式 tensorflow 的 SAME 数据丢失不一致
Posted
技术标签:
【中文标题】训练和验证模式 tensorflow 的 SAME 数据丢失不一致【英文标题】:Inconsistency in loss on SAME data for train and validation modes tensorflow 【发布时间】:2021-02-17 13:56:18 【问题描述】:我正在使用图像实现语义分割模型。作为一种好的做法,我只用一张图像测试了我的训练管道,并试图过度拟合该图像。令我惊讶的是,当使用完全相同的图像进行训练时,损失达到预期的接近 0,但在评估相同的图像时,损失要高得多,并且随着训练的继续而不断上升。因此,当training=False
时,分段输出是垃圾,但是当使用training=True
运行时,效果很好。
为了让任何人都能重现这一点,我采用了官方的segmentation tutorial 并对其进行了一些修改,以便从头开始训练一个卷积网络,并且只有一张图像。该模型非常简单,只是一个带有批归一化和 Relu 的 Conv2D 序列。结果如下
如您所见,loss 和 eval_loss 确实不同,对图像进行推理在训练模式下会得到完美的结果,而在 eval 模式下则是垃圾。
我知道 Batchnormalization 在推理时间的表现不同,因为它使用训练时计算的平均统计数据。尽管如此,由于我们只使用 1 张相同的图像进行训练并在相同的图像中进行评估,这不应该发生,对吧?此外,我在 Pytorch 中使用相同的优化器实现了相同的架构,但这并没有发生。使用 pytorch 进行训练,eval_loss 收敛到训练损失
在这里你可以找到上面提到的https://colab.research.google.com/drive/18LipgAmKVDA86n3ljFW8X0JThVEeFf0a#scrollTo=TWDATghoRczu 最后还有 Pytorch 实现
【问题讨论】:
【参考方案1】:它必须对 tensorflow 使用的默认值做更多的事情。批标准化有一个参数momentum
,它控制批统计的平均。公式为:moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
如果您在 BatchNorm 层中设置 momentum=0.0
,则平均统计信息应与当前批次(仅 1 张图像)的统计信息完美匹配。如果这样做,您会看到验证损失几乎立即与训练损失匹配。此外,如果您尝试使用 momentum=0.9
(这是 pytorch 中的等效默认值),它会更快地工作和收敛(就像在 pytorch 中一样)。
【讨论】:
以上是关于训练和验证模式 tensorflow 的 SAME 数据丢失不一致的主要内容,如果未能解决你的问题,请参考以下文章
如何将 Tensorflow 数据集 API 与训练和验证集一起使用
使用相同的图在 TensorFlow 中显示训练和验证的准确性
与 Eager 和 Graph 模式无关的 Tensorflow 训练