Batch Normlization原理

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Batch Normlization原理相关的知识,希望对你有一定的参考价值。

参考技术A

在深度学习中,由于问题的复杂性,我们往往会使用较深层数的网络进行训练,相信很多炼丹的朋友都对调参的困难有所体会,尤其是对深层神经网络的训练调参更是困难且复杂。

在这个过程中,我们需要去尝试不同的学习率、初始化参数方法(例如Xavier初始化)等方式来帮助我们的模型加速收敛。深度神经网络之所以如此难训练,其中一个重要原因就是网络中层与层之间存在高度的关联性与耦合性。下图是一个多层的神经网络,层与层之间采用全连接的方式进行连接。

我们规定左侧为神经网络的底层,右侧为神经网络的上层。那么网络中层与层之间的关联性会导致如下的状况:随着训练的进行,网络中的参数也随着梯度下降在不停更新。

Batch Normalization的原论文作者给了Internal Covariate Shift一个较规范的定义: 在深层网络训练的过程中,由于网络中参数变化而引起内部结点数据分布发生变化的这一过程被称作Internal Covariate Shift

这句话怎么理解呢?我们定义每一层的线性变换为 ,其中 代表层数;非线性变换为 ,其中, 为 第 层的激活函数。

醉着梯度下降的进行,每一层的参数 与 都会被更新,那么 的分布也就发生了改变,进而 也同样出现分布的改变。而 作为第 层的输入,意味着 层就需要去不停适应这种数据分布的变化,这一过程就被叫做Internal Covariate Shift。

(1)上层网络需要不停地调整来适应输入数据分布的变化,导致网络学习速度的降低

我们在上面提到梯度下降的过程会让每一层的参数 和 发生变化,进而使得每一层的线性与非线性计算结果分布产生变化。后层网络就要不停地适应这种分布的变化,这个时候就会使得整个网络的学习速率变慢。

(2)网络的训练过程容易陷入梯度饱和区,减缓网络收敛的速度

当我们在神经网络中采用饱和激活函数(saturated activation function)时,例如sigmoid,tanh激活函数,很容易使得模型训练陷入梯度饱和区(saturated regime)。随着模型训练的进行,我们的参数 会逐渐更新并变大,此时 就会随着变大,并且 还要收到更底层网络参数 的影响,随着网络层数的增加, 很容易陷入梯度饱和区,此时梯度会变得很小甚至接近与0,参数的更新速度就会变慢,进而就会放慢网络的收敛速度。

对于激活函数梯度饱和的问题,有两种解决思路:第一种就是使用ReLU等非线性激活函数,可以一定程度上解决训练陷入梯度饱和区的问题。另一种就是,我们可以让激活函数的分布保持在一个稳定的状态,来尽可能避免它们陷入梯度饱和区,也就是Normalization的思路。

要缓解ICS的问题,就要明白它产生的原因。ICS产生的原因是由于参数更新带来的网络中每一层输入值分布的改变,并且随着网络层数的加深而变得更加严重,因此我们可以通过固定每一层网络输入值的分布来对减缓ICS问题。

(1)白化

白化(Whitening)是机器学习里面常用的一种规范化数据分布的方法,主要是PCA白化与ZCA白化。白化是对输入数据分布进行变换,进而达到以下两个目的:

(2)Batch Normalization提出

既然白化可以解决这个问题,为什么我们还要提出别的解决办法?当然是现有的方法具有一定的缺陷,白化主要有以下两个问题:

既然有了上面两个问题,那我们的解决思路就很简单,一方面,我们提出的normalization方法要能够简化计算过程;另一方面又需要经过规范化处理后让数据尽可能保留原始的表达能力。于是就有了简化+改进版的白化——Batch Normalization。

既然白化计算过程比较复杂,那我们就简化一点,比如我们可以尝试单独对每个特征进行normalizaiton就可以了,让每个特征都有均值为0,方差为1的分布就OK。

另一个问题,既然白化操作减弱了网络中每一层输入数据表达能力,那我就再加个线性变换操作,让这些数据再能够尽可能恢复本身的表达能力就好了。

因此,基于上面两个解决问题的思路,作者提出了Batch Normalization,下一部分来具体讲解这个算法步骤。

举例计算:

上图展示了一个batch size为2(两张图片)的Batch Normalization的计算过程。

假设feature1、feature2分别是由image1、image2经过一系列卷积池化后得到的特征矩阵,feature的channel数均为2,那么 代表该batch的所有feature的channel1的数据,同理 代表该batch的所有feature的channel2的数据。

然后分别计算 和 的均值和方差,得到我们的 和 两个向量。

然后在根据标准差计算公式分别计算每个channel的值(公式中的 是一个很小的常量,防止分母为零的情况)

在我们训练网络的过程中,我们是通过一个batch一个batch的数据进行训练的,但是我们在预测过程中通常都是输入一张图片进行预测,此时batch size为1,如果在通过上述方法计算均值和方差就没有意义了。所以我们在训练过程中要去不断的计算每个batch的均值和方差,并使用移动平均(moving average)的方法记录统计的均值和方差,在我们训练完后我们可以近似认为我们所统计的均值和方差就等于我们整个训练集的均值和方差。然后在我们验证以及预测过程中,就使用我们统计得到的均值和方差进行标准化处理。

在训练过程中,均值 和方差 通过计算当前批次数据得到的记为 和 ,而我们在预测过程中所使用的均值和方差是一个训练过程中保存的统计量,记 和 , 和 的具体更新策略如下,momentum默认取值为0.1:

需要注意的是:

下面是使用pytorch做的测试:

(1)bn_process函数是自定义的bn处理方法验证是否和使用官方bn处理方法结果一致。在bn_process中计算输入batch数据的每个维度(这里的维度是channel维度)的均值和标准差(标准差等于方差开平方),然后通过计算得到的均值和总体标准差对feature每个维度进行标准化,然后使用均值和样本标准差更新统计均值和标准差。
(2)初始化统计均值是一个元素为0的向量,元素个数等于channel深度;初始化统计方差是一个元素为1的向量,元素个数等于channel深度,初始化 , 。

设置一个断点进行调试,查看下官方bn对feature处理后得到的统计均值和方差。我们可以发现官方提供的bn的running_mean和running_var和我们自己计算的calculate_mean和calculate_var是一模一样的(只是精度不同):

输出结果如下:

从结果可以看出:通过自定义bn_process函数得到的输出以及使用官方bn处理得到输出,明显结果是一样的(只是精度不同)。

Batch Normalization在实际工程中被证明了能够缓解神经网络难以训练的问题,BN具有的有事可以总结为以下四点:

(1)BN使得网络中每层输入数据的分布相对稳定,加速模型学习速度

BN通过规范化与线性变换使得每一层网络的输入数据的均值与方差都在一定范围内,使得后一层网络不必不断去适应底层网络中输入的变化,从而实现了网络中层与层之间的解耦,允许每一层进行独立学习,有利于提高整个神经网络的学习速度。

(2)BN使得模型对网络中的参数不那么敏感,简化调参过程,使得网络学习更加稳定

在神经网络中,我们经常会谨慎地采用一些权重初始化方法(例如Xavier)或者合适的学习率来保证网络稳定训练。
当学习率设置太高时,会使得参数更新步伐过大,容易出现震荡和不收敛。但是使用BN的网络将不会受到参数数值大小的影响。
因此,在使用Batch Normalization之后,抑制了参数微小变化随着网络层数加深被放大的问题,使得网络对参数大小的适应能力更强,此时我们可以设置较大的学习率而不用过于担心模型divergence的风险。

3)BN允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题

在不使用BN层的时候,由于网络的深度与复杂性,很容易使得底层网络变化累积到上层网络中,导致模型的训练很容易进入到激活函数的梯度饱和区;通过normalize操作可以让激活函数的输入数据落在梯度非饱和区,缓解梯度消失的问题;另外通过自适应学习 与 又让数据保留更多的原始信息。

(4)BN具有一定的正则化效果

在Batch Normalization中,由于我们使用mini-batch的均值与方差作为对整体训练样本均值与方差的估计,尽管每一个batch中的数据都是从总体样本中抽样得到,但不同mini-batch的均值与方差会有所不同,这就为网络的学习过程中增加了随机噪音,与Dropout通过关闭神经元给网络训练带来噪音类似,在一定程度上对模型起到了正则化的效果。

Batch Normalization原理与实战
Batch Normalization详解以及pytorch实验

以上是关于Batch Normlization原理的主要内容,如果未能解决你的问题,请参考以下文章

Batch Normalization 原理

Batch Normalization原理介绍

Cesium原理篇:Batch

Spring boot Batch 的启动原理- Configuration

BN(Batch Normalization) 原理与使用过程详解

深度学习原理与框架- batch_normalize(归一化操作)