keras batchnorm 的测试性能很差
Posted
技术标签:
【中文标题】keras batchnorm 的测试性能很差【英文标题】:keras batchnorm has awful test performance 【发布时间】:2017-02-28 09:44:21 【问题描述】:在对训练数据进行交叉验证期间,使用 batchnorm 可以显着提高性能。但是(在整个训练集上重新训练之后)batchnorm 层的存在完全破坏了模型对保持集的泛化。这有点令人惊讶,我想知道我是否错误地执行了测试预测。
没有batchnorm层的泛化很好(对于我的项目目标来说不够高,但对于这样一个简单的网络来说是合理的)。
我无法共享我的数据,但有人看到明显的实施错误吗?是否有应该设置为测试模式的标志?我在文档中找不到答案,并且 dropout(也应该有不同的训练/测试行为)按预期工作。谢谢!
代码:
from keras.callbacks import EarlyStopping
early_stopping = EarlyStopping(monitor='val_loss', patience=10)
from keras.callbacks import ModelCheckpoint
filepath="L1_batch1_weights.best.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='auto')
init = 'he_normal'
act = 'relu'
neurons1 = 80
dropout_rate = 0.5
model = Sequential()
model.add(Dropout(0.2, input_shape=(5000,)))
model.add(Dense(neurons1))
model.add(BatchNormalization())
model.add(Activation(act))
model.add(Dropout(dropout_rate))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer="adam", metrics=["accuracy"])
my_model = model.fit(X_train, y_train, batch_size=128, nb_epoch=150, validation_data =(X_test, y_test),callbacks=[early_stopping, checkpoint])
model.load_weights("L1_batch1_weights.best.hdf5")
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
print("Created model and loaded weights from file")
probs = model.predict_proba(X_test,batch_size=2925)
fpr, tpr, thresholds = roc_curve(y_test, probs)
【问题讨论】:
如果您阅读批处理文件的标签,您会发现它是关于在 Windows 上复制文件等的。 @Noodles,您的评论与问题有何关联? 我编辑了您的问题并删除了批处理文件标签。它是WINDOWS 的shell 脚本语言。与神经网络无关。 包含批处理文件标记是无意的。谢谢。 【参考方案1】:来自docs:“在训练期间,我们使用每批统计数据来标准化数据,在测试期间,我们使用在训练阶段计算的运行平均值。”
在我的案例中,训练批量大小为 128。在测试时,我手动将批量大小设置为完整测试集的大小 (2925)。
有意义的是,用于一个批次大小的统计数据显然与显着不同的批次大小无关。
将测试批次大小更改为训练批次大小 (128) 会产生更稳定的结果。我使用预测批量大小来观察效果:对于任何批量大小 +/- 3 倍的训练批量大小,预测结果都是稳定的,但性能会下降。
这里有一些关于测试批量大小的影响以及与 load_weights() 一起使用时使用 batchnorm 的讨论: https://github.com/fchollet/keras/issues/3423
【讨论】:
推理时的批量大小对推理输出绝对没有影响。以上是关于keras batchnorm 的测试性能很差的主要内容,如果未能解决你的问题,请参考以下文章
如何在 Keras 的测试期间使用 Batch Normalization?
keras 中的 BatchNormalization 是如何工作的?
Keras 的 BatchNormalization 和 PyTorch 的 BatchNorm2d 的区别?
tf.keras 模型到 coreml 模型,不支持 BatchNormalization