tf-slim 批量规范:训练/推理模式之间的不同行为
Posted
技术标签:
【中文标题】tf-slim 批量规范:训练/推理模式之间的不同行为【英文标题】:tf-slim batch norm: different behaviour between training/inference mode 【发布时间】:2019-01-22 21:32:54 【问题描述】:我正在尝试基于流行的 slim implementation 或 mobilenet_v2
训练一个 tensorflow 模型,并且正在观察我无法解释(我认为)与批量标准化相关的行为。
问题总结
推理模式下的模型性能最初有所提高,但在很长一段时间后开始产生微不足道的推理(全部接近零)。在训练模式下运行时,即使在评估数据集上也能保持良好的性能。评估性能受批标准化衰减/动量率的影响......不知何故。
下面有更广泛的实现细节,但我可能会因为文字墙而失去大部分人,所以这里有一些图片让你感兴趣。
下面的曲线来自我在训练时调整了bn_decay
参数的模型。
0-370k:bn_decay=0.997
(默认)
370k-670k:bn_decay=0.9
670k+:bn_decay=0.5
(橙色)训练(训练模式)和(蓝色)评估(推理模式)的损失。低是好的。
推理模式下评估数据集上模型的评估指标。高是好的。
我试图制作一个演示问题的最小示例 - MNIST 上的分类 - 但失败了(即分类效果很好,我遇到的问题没有展示出来)。对于无法进一步减少事情,我深表歉意。
实现细节
我的问题是 2D 姿态估计,针对以关节位置为中心的高斯。它本质上与语义分割相同,除了使用 tf.losses.l2_loss(sigmoid(logits) - gaussian(label_2d_points))
而不是使用 softmax_cross_entropy_with_logits(labels, logits)
(我使用术语“logits”来描述我的学习模型的未激活输出,尽管这可能不是最好的术语)。
推理模型
在预处理我的输入之后,我的 logits 函数是对基本 mobilenet_v2 的范围调用,后跟一个未激活的卷积层,以使过滤器的数量合适。
from slim.nets.mobilenet import mobilenet_v2
def get_logtis(image):
with mobilenet_v2.training_scope(
is_training=is_training, bn_decay=bn_decay):
base, _ = mobilenet_v2.mobilenet(image, base_only=True)
logits = tf.layers.conv2d(base, n_joints, 1, 1)
return logits
训练操作
我已经尝试过tf.contrib.slim.learning.create_train_op
以及自定义训练操作:
def get_train_op(optimizer, loss):
global_step = tf.train.get_or_create_global_step()
opt_op = optimizer.minimize(loss, global_step)
update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
update_ops.add(opt_op)
return tf.group(*update_ops)
我正在使用 tf.train.AdamOptimizer
和 learning rate=1e-3
。
训练循环
我正在使用tf.estimator.Estimator
API 进行培训/评估。
行为
培训最初进展顺利,预期表现会大幅提升。这与我的预期一致,因为最后一层被快速训练以解释预训练基础模型输出的高级特征。
但是,经过长时间(60k 步,batch_size 8,在 GTX-1070 上约 8 小时)后,我的模型在推理模式下运行时开始输出接近零的值 (~1e-11) em>,即is_training=False
。在 *training mode, i.e.
is_training=True` 下运行时,完全相同的模型会继续改进,即使在评估集上也是如此。我已经直观地验证了这一点。
经过一些实验后,我将 bn_decay
(批量标准化衰减/动量率)从默认的 0.997
更改为 0.9
,步长约为 370k(也尝试了 0.99
,但这并没有起到多大作用差异)并观察到准确性的即时改进。在推理模式下对推理的目视检查显示,在预期位置的 ~1e-1
阶推断值中有明显的峰值,与训练模式中峰值的位置一致(尽管值要低得多)。这就是为什么准确率显着提高,但损失 - 虽然更具波动性 - 并没有太大改善。
这些效果在更多的训练后下降并恢复到全零推理。
我进一步将bn_decay
降低到 0.5,步长约为 670k。这导致损失和准确性的改善。我可能要等到明天才能看到长期效果。
下面给出了损失和评估指标图。请注意,评估指标基于 logits 的 argmax,高为佳。损失以实际值为准,低为好。橙色在训练集上使用is_training=True
,而蓝色在评估集上使用is_training=False
。大约 8 的损失与所有零输出一致。
其他说明
我还尝试关闭 dropout(即始终使用is_training=False
运行 dropout 层),并没有发现任何差异。
我已经尝试过从1.7
到1.10
的所有版本的tensorflow。没有区别。
我从一开始就使用bn_decay=0.99
从预训练检查点训练模型。与使用默认 bn_decay
的行为相同。
批量大小为 16 的其他实验导致质量相同的行为(尽管由于内存限制我无法同时评估和训练,因此对批量大小 8 进行定量分析)。
我使用相同的损失和tf.layers
API 训练了不同的模型,并从头开始训练。他们工作得很好。
从头开始训练(而不是使用预训练的检查点)会产生类似的行为,但需要更长的时间。
总结/我的想法:
我相信这不是过度拟合/数据集问题。当使用is_training=True
运行时,该模型会在峰值位置和幅度方面对评估集做出明智的推断。
我相信这不是不运行更新操作的问题。我以前没有使用过slim
,但是除了使用arg_scope
之外,它看起来与我广泛使用的tf.layers
API 并没有太大的不同。我还可以检查移动平均值并观察它们随着训练的进行而变化。
更改bn_decay
值显着暂时影响了结果。我承认 0.5
的值低得离谱,但我的想法已经不多了。
我尝试用momentum=0.997
(即动量与默认衰减值一致)将slim.layers.conv2d
层替换为tf.layers.conv2d
,并且行为是相同的。
使用预训练权重和 Estimator
框架的最小示例可用于 MNIST 分类,无需修改 bn_decay
参数。
我查看了有关 tensorflow 和模型 github 存储库的问题,但除了 this 之外没有发现太多问题。我目前正在尝试使用较低的学习率和更简单的优化器 (MomentumOptimizer
),但这更多是因为我的想法已经用完了,而不是因为我认为这就是问题所在。
可能的解释
我的最佳解释是我的模型参数以某种方式快速循环,以至于移动统计数据无法跟上批量统计数据。我从来没有听说过这种行为,它并不能解释为什么模型会在更多时间后恢复到不良行为,但这是我所拥有的最好的解释。 移动平均代码中可能存在错误,但它在其他所有情况下都非常适合我,包括简单的分类任务。在我能提出一个更简单的例子之前,我不想提出问题。无论如何,我的想法已经不多了,调试周期很长,而且我已经花了太多时间在这上面。很高兴提供更多细节或按需运行实验。也很高兴发布更多代码,但我担心这会吓跑更多人。
提前致谢。
【问题讨论】:
【参考方案1】:使用 Adam 将学习率降低到 1e-4
和使用 Momentum 优化器(使用 learning_rate=1e-3
和 momentum=0.9
)都解决了这个问题。我还发现了this post,这表明该问题跨越了多个框架,并且由于优化器和批处理规范化之间的交互,是某些网络的未记录病态。我不认为这是优化器由于学习率太高而未能找到合适的最小值的简单案例(否则训练模式下的性能会很差)。
我希望这可以帮助遇到同样问题的其他人,但我离满意还有很长的路要走。我很高兴听到其他解释。
【讨论】:
以上是关于tf-slim 批量规范:训练/推理模式之间的不同行为的主要内容,如果未能解决你的问题,请参考以下文章