为什么会出现Batch Normalization层

Posted NodYoung

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了为什么会出现Batch Normalization层相关的知识,希望对你有一定的参考价值。

训练模型时的收敛速度问题

众所周知,模型训练需要使用高性能的GPU,还要花费大量的训练时间。除了数据量大及模型复杂等硬性因素外,数据分布的不断变化使得我们必须使用较小的学习率、较好的权重初值和不容易饱和的激活函数(如sigmoid,正负两边都会饱和)来训练模型。这样速度自然就慢了下来。

下面先简单示例一下数据分布的不断变化为什么会带来这些问题,如图:

我们使用Wx+b=0对小黄和小绿进行分类。由于数据点仅落在第一象限中很小的区域里,那么如果我们随机初始化权重W,需要迭代很多次才会得到有效的分割,这势必会带来求解速率的下降,并且容易遇到局部最优解。

如果我们遇到的仅仅是这么一个简单的问题,大家用屁股都能想出来应该怎么做。均值归一化呀,把数据挪到原点附近,这样问题就解决了。但是在深度神经网络里,数据分布的不断变化至少来自两方面:a) 每批训练数据的分布各不相同(batch梯度下降),那么网络就要在每次迭代都去学习适应不同的分布。b) 网络中某一层输入数据的分布发生改变,后面几层就会被累积放大下去,这样模型就需要去不断适应学习新的数据分布。论文中把网络中间层在训练过程中,参数不断变化导致的各层输入分布的变化 称为 Internal Covariate Shift。

怎么办呢?这个时候我们可能就会想,如果在每一层输入的时候,再加个预处理操作那该有多好啊。好,BatchNormalization来了。

怎么加入归一化BatchNormalization

  1. 这里加入的BatchNormalization层可不想我们想象的那么简单,它是一个可学习、有参数的网络层。为什么呢?如果我们直接对网络某一层A的输出数据做归一化然后送入网络下一层B,这样会影响到本层所学习到的特征。(比如网络中某一层学习到的特征分布在S型激活函数的某一侧,你强制把它归一化到标准差为1内。)怎么办?论文引入了可学习的参数γ、β(文中把他称为变换重构)来保留其学习到的特征,公式如下:

    接着论文中进一步证明了这一切参数都是可以链式求导的(具体见论文),因此γ、β也就可以像权重W那样不断迭代优化啦。

  2. 另一个问题是BatchNormalization层加在哪?有两个落户地址:a) W x +b 之后,非线性激活函数之前 b) 非线性激活函数后。作者认为前者效果要更好一点,给出的解释是:前一个激活层是非线性输出,其分布很可能在训练中变化,而Wx+b更可能“more Gaussian”。(当然也有人简单尝试了一下Post-activation batch normalization,效果也还不错)

其他一切细节自己看论文吧。

BatchNormalization带来的好处

  1. 可以选择比较大的初始学习率,让你的训练速度飙涨。
  2. 移除或使用较低的drop out、L2正则项参数。
  3. b 可以忽略了,因为b 的作用其实被 β 代替了。

在tensorflow中应用Batch Normalization

A GENTLE GUIDE TO USING BATCH NORMALIZATION IN TENSORFLOW 这篇文章介绍的特别好,并且末尾的GITHUB代码里对全连接网络的No batch normalization、Standard batch normalization和Post-activation batch normalization三种方法进行了实现和对比。贴一张他的实验结果图:

这里我对其Standard batch normalization进行了简单的修改,贴于此:

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


# define our typical fully-connected + batch normalization + nonlinearity set-up
def dense(x, size, scope):
    return tf.contrib.layers.fully_connected(x, size,
                                             activation_fn=None,
                                             scope=scope)


def dense_batch_relu(x, phase, scope):
    with tf.variable_scope(scope):
        h1 = tf.contrib.layers.fully_connected(x, 100,
                                               activation_fn=None,
                                               scope='dense')
        h2 = tf.contrib.layers.batch_norm(h1,
                                          center=True, scale=True,
                                          is_training=phase,
                                          scope='bn')
        return tf.nn.relu(h2, 'relu')


tf.reset_default_graph()
x = tf.placeholder('float32', (None, 784), name='x')
y = tf.placeholder('float32', (None, 10), name='y')
phase = tf.placeholder(tf.bool, name='phase')

h1 = dense_batch_relu(x, phase,'layer1')
h2 = dense_batch_relu(h1, phase, 'layer2')
logits = dense(h2, 10, 'logits')

with tf.name_scope('accuracy'):
    accuracy = tf.reduce_mean(tf.cast(
            tf.equal(tf.argmax(y, 1), tf.argmax(logits, 1)),
            'float32'))

with tf.name_scope('loss'):
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))


def train(mnist):
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # Ensures that we execute the update_ops before performing the train_step
        train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    history = []
    iterep = 500
    for i in range(iterep * 30):
        x_train, y_train = mnist.train.next_batch(100)
        sess.run(train_step,
                 feed_dict='x:0': x_train,
                            'y:0': y_train,
                            'phase:0': 1)
        if (i + 1) % iterep == 0:
            epoch = (i + 1)/iterep
            tr = sess.run([loss, accuracy],
                          feed_dict='x:0': mnist.train.images,
                                     'y:0': mnist.train.labels,
                                     'phase:0': 1)
            t = sess.run([loss, accuracy],
                         feed_dict='x:0': mnist.test.images,
                                    'y:0': mnist.test.labels,
                                    'phase:0': 0)
            history += [[epoch] + tr + t]
            print(history[-1])
    return history


def main(argv=None):
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    train(mnist)


if __name__ == '__main__':
    tf.app.run()

参考:
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
深度学习中 Batch Normalization为什么效果好?
Why does batch normalization help?
深度学习(二十九)Batch Normalization 学习笔记
论文笔记-Batch Normalization
《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》阅读笔记与实现
Implementing Batch Normalization in Tensorflow
A GENTLE GUIDE TO USING BATCH NORMALIZATION IN TENSORFLOW

以上是关于为什么会出现Batch Normalization层的主要内容,如果未能解决你的问题,请参考以下文章

请问batch_normalization做了normalization后为什么要变回来?

Batch Normalization批标准化是什么? | BN有啥用 | Batch Normalization是什么

batch-normalization为什么效果好

BN(Batch Normalization)

12. 批标准化(Batch Normalization )

Batch normalization