(Batch Normalization)批标准化算法理解

Posted 岳飞传

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了(Batch Normalization)批标准化算法理解相关的知识,希望对你有一定的参考价值。

批标准化

1.概念

batch normalization,就是“批规范化”,即为了克服神经网络层数加深,收敛速度变慢,常常导致梯度消失(vanishing gradient problem)或梯度爆炸(gradient explore),通过引入批标准化来规范某些层或者所有层的输入,从而固定每层输入信号的均值与方差。

2.方法

批标准化: 一般用在非线性映射(激活函数)之前,对 x=Wu+b x = W u + b 做规范化,使结果(输出信号各个维度)的均值为0,方差为1 。让每一层的输入有一个稳定的分布会有利于网络的训练。
BN算法在网络中的应用
传统的神经网络,只是在将样本x输入输入层之前对x进行标准化处理(减均值,除标准差),以降低样本间的差异性。BN是在此基础上,不仅仅只对输入层的输入数据x进行标准化,还对每个隐藏层的输入进行标准化。如下图

标准化后的x乘以权值矩阵 Wh1 W h 1 加上偏置 bh1 b h 1 得到第一层的输入 Wh1x+bh1 W h 1 x + b h 1 ,经过激活函数得到 h1=ReLU(Wh1x+bh1) h 1 = R e L U ( W h 1 x + b h 1 ) ,然而加入BN后, h1的计算流程如虚线框所示:
1. 矩阵x先经过 Wh1 W h 1 的线性变换后得到 s1 s 1 (注:因为减去batch的平均值μB后,b的作用会被抵消掉,所提没必要加入b了)

将s1再减去batch的平均值 μB μ B ,并除以batch的标准差 σ2B+ϵ σ B 2 + ϵ −−−−−√得到s2. ϵ是为了避免除数为0时所使用的微小正数。
其中

μB=1mmi=0Wh1xi μ B = 1 m ∑ i = 0 m W h 1 x i
σ2B=1mmi=0(Wh1xiμB)2 σ B 2 = 1 m ∑ i = 0 m ( W h 1 x i − μ B ) 2
(注:由于这样做后s2基本会被限制在正态分布下,使得网络的表达能力下降。为解决该问题,引入两个新的参数:γ,β. γ和β是在训练时网络自己学习得到的。)

将s2乘以γ调整数值大小,再加上β增加偏移后得到s3
s3经过激活函数后得到h1
需要注意的是,上述的计算方法用于在训练过程中。在测试时,所使用的μ和 σ2 σ 2 是整个训练集的均值 μp μ p 和方差 σ2p σ p 2 . 整个训练集的均值 μp μ p 和方差 σ2p σ p 2 的值通常是在训练的同时用移动平均法来计算的.

3.优缺点

批标准化通过规范化让激活函数分布在线性区间,结果就是加大了梯度,让模型更大胆的进行梯度下降,其优点如下:

  • 加大探索补偿,加快收敛速度
  • 更容易跳出局部最小
  • 破坏原来的数据分布,一定程度上缓解过拟合

4.示例

待续……

参考:
1.http://blog.csdn.net/hjimce/article/details/50866313
2. http://blog.csdn.net/whitesilence/article/details/75667002
3. https://www.zhihu.com/question/38102762
4. 《TensorFlow 技术解析与实战》李嘉璇著
5.《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》
6. 《Spatial Transformer Networks》

以上是关于(Batch Normalization)批标准化算法理解的主要内容,如果未能解决你的问题,请参考以下文章

Batch Normalization批标准化是什么? | BN有啥用 | Batch Normalization是什么

12. 批标准化(Batch Normalization )

莫烦课程Batch Normalization 批标准化

[转] 深入理解Batch Normalization批标准化

[转]深入理解Batch Normalization批标准化

PyTorch学习(十四)Batch_Normalization(批标准化)