keras BatchNormalization 之坑

Posted fosen

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了keras BatchNormalization 之坑相关的知识,希望对你有一定的参考价值。

任务简述:最近做一个图像分类的任务, 一开始拿vgg跑一个baseline,输出看起来很正常:

 

技术图片

 

随后,我尝试其他的一些经典的模型架构,比如resnet50, xception,但训练输出显示明显异常:

 

技术图片

val_loss 一直乱蹦,val_acc基本不发生变化。

检查了输入数据没发现问题,因此怀疑是网络构造有问题, 对比了vgg同xception, resnet在使用layer上的异同,认为问题可能出在BN层上,将vgg添加了BN层之后再训练果然翻车。

技术图片

 

翻看keras BN 的源码, 原来keras 的BN层的call函数里面有个默认参数traing, 默认是None。此参数意义如下:

training=False/0, 训练时通过每个batch的移动平均的均值、方差去做批归一化,测试时拿整个训练集的均值、方差做归一化

training=True/1/None,训练时通过当前batch的均值、方差去做批归一化,测试时拿整个训练集的均值、方差做归一化

 

 当training=None时,训练和测试的批归一化方式不一致,导致validation的输出指标翻车。

 

用keras的BN时切记要设置training=False!!!

def build_model():
    Inputs = Input(shape=intput_shape, name=input)
    x_tmp = Lambda(lambda c: tf.image.rgb_to_grayscale(c))(Inputs)
    x_tmp = Conv2D(64, (3, 3), activation=relu)(x_tmp)
    x_tmp = Conv2D(64, (3, 3), activation=relu)(x_tmp)
    x_tmp = BatchNormalization(x_tmp, training=False)
    x_tmp = MaxPooling2D(pool_size=(2, 2))(x_tmp)

    x_tmp = Flatten()(x_tmp)
    x_tmp = Dense(128, activation=relu)(x_tmp)
    outputs = Dense(10, activation=softmax)(x_tmp)
    model = Model(Inputs, outputs)
    return model

 

参考:

https://arxiv.org/pdf/1502.03167v3.pdf

https://github.com/keras-team/keras/blob/master/keras/layers/normalization.py#L16

 

以上是关于keras BatchNormalization 之坑的主要内容,如果未能解决你的问题,请参考以下文章

Keras(TF 后端)中的 BatchNormalization 实现 - 激活之前还是之后?

Keras 中的 BatchNormalization 层给出了意想不到的输出值

keras BatchNormalization 之坑

Keras 的 BatchNormalization 和 PyTorch 的 BatchNorm2d 的区别?

你能用 BatchNormalization 解释神经网络中的 Keras get_weights() 函数吗?

如何在 Keras 的测试期间使用 Batch Normalization?