深度学习中 Internal Covariate Shift 问题以及 Batch Normalization 的作用

Posted 不堪沉沦

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深度学习中 Internal Covariate Shift 问题以及 Batch Normalization 的作用相关的知识,希望对你有一定的参考价值。


前言

提示:这里可以添加本文要记录的大概内容:
最近在看到一篇论文:FedBN: Federated Learning on Non-IID Features via Local Batch Normalization

论文合理的利用了 Batch Norlization 解决了联邦学习中,不同的边缘节点上数据存在的非独立同分布(Non-IID)的问题,降低了 Feature shift 带来的影响。

为此,这里记录对 Batch Normalization 的理解,以及DedBN论文的总结。


一、Batch Normalization是什么?

相信大家第一次接触,应该是在训练模型的时候,模型训练困难,总是波动很大或者损失函数下降十分缓慢,甚至是无法收敛。

有研究表明:随着 DNN 网络层次的加深,参数的变化导致每一层的输入分布会发生改变,进而上层的网络需要不停地去适应这些分布变化,使得我们的模型训练变得困难。

为什么会这样呢?因为 Internal Covariate Shift

1.1 Internal Covariate Shift

因为在训练过程中,每一层的输入分布会随着前一层参数的变化而变化,这种现象称之为 Internal Covariate Shift1

下图为一个多层全连接的神经网络结构示意图,左侧的网络层为底层,右侧的网络层称之为顶层。

以本图为例,每一层 l l l 可以理解为两个操作:

  • 线性变换: Y [ l ] = W [ l ] × i n p u t + b [ l ] Y^{[l]} = W^{[l]} \\times input + b^{[l]} Y[l]=W[l]×input+b[l]。其中 W [ l ] W^{[l]} W[l]表示 l l l层的权重, b [ l ] b^{[l]} b[l]表示 l l l层的偏置, Y [ l ] Y^{[l]} Y[l]表示 l l l层的线性输出
  • 非线性变换: Z [ l ] = g [ l ] ( Y [ l ] ) Z^{[l]} = g^{[l]}(Y^{[l]}) Z[l]=g[l](Y[l]) g [ l ] g^{[l]} g[l]表示 l l l层的激活函数

在模型的反向传播过程中,根据计算的梯度来更新每一层的 W [ l ] W^{[l]} W[l] b [ l ] b^{[l]} b[l],那么 Y [ l ] Y^{[l]} Y[l]的分布也会改变, Z [ l ] Z^{[l]} Z[l]的分布也随之改变。

然而 Z [ l ] Z^{[l]} Z[l]作为下一层 ( l + 1 ) (l+1) (l+1)的输入,这就使得 ( l + 1 ) (l+1) (l+1)层的神经元也需要不断的适应这样的变化,这就会降低整个网络的收敛速度。

1.2 Internal Covariate Shift 带来的影响

① 上层网络需要不停调整来适应输入数据分布的变化,导致网络学习速度的降低

如上所提到的,梯度下降使得每一层的参数都在不断发生变化,
进而使得每一层的线性与非线性计算结果分布产生变化。
后层网络就要不停地去适应这种分布变化,这个时候就会使得整个网络的学习速率过慢。

② 网络的训练过程容易陷入梯度饱和区,减缓网络收敛速度

梯度饱和和梯度消失的后果有点类似(但不要混淆哦)。

梯度饱和:常常是和激活函数相关的,比如sigmod和tanh就属于典型容易进入梯度饱和区的函数。
即自变量进入某个区间后,梯度变化会非常小,
表现在图上就是函数曲线进入某些区域后,越来越趋近一条直线,梯度变化很小。
梯度饱和会导致训练过程中梯度变化缓慢,从而造成模型训练缓慢

下图为 s i g m o i d ( 左 ) sigmoid(左) sigmoid() T a n h ( 右 ) Tanh(右) Tanh() 的激活函数与对应的一阶导数曲线图2


两者的导数均在原点处取得最大值, s i g m o i d sigmoid sigmoid最大为0.25, T a n h Tanh Tanh最大为1;
在远离原点的正负方向上,两者导数均趋近于0,即存在饱和区。

饱和区: 一旦陷入饱和区,两者的偏导都接近于0,导致权重的更新量很小,比如某些权重很大,导致相关的神经元一直陷在饱和区,更新量又接近于0,以致很难跳出或者要花费很长时间才能跳出饱和区。

1.3 如何减缓 Internal Covariate Shift 问题带来的影响

注意:Internal Covariate Shift 是因为参数更新带来的网络中每一层输入值分布的改变,并且随着网络层数的加深而变得更加严重.
因此可以通过固定每一层网络输入值的分布来对减缓ICS问题。

白化(Whitening)

这不是本文探讨的重点,想了解的可以参考 Whitening 3 ,这里我们只需要知道,白化后的数据会有如下性质:

  • 特征之间相关性较低;
  • 所有特征具有相同的方差。

通过白化操作,可以有效地减缓 Internal Covariate Shift 的问题,进而固定了每一层网络输入分布,加速网络训练过程的收敛。

白化存在的问题

白化存在如下缺点:

  • 白化过程计算成本太高,比如 PCA 中,需要计算协方差矩阵,并且在每一轮训练中的每一层我们都需要做如此高成本计算的白化操作;
  • 白化过程由于改变了网络每一层的分布,因而改变了网络层中本身数据的表达能力。底层网络学习到的参数信息会被白化操作丢失掉。

二、Batch Normalization

这里说明下传统 Normalization 使用的原因:

(1)由于神经网络学习过程本质上是为了学习数据的分布,一旦训练数据与测试数据的分布不同,
那么网络的泛化能力也大大降低;
(2)另一方面,在mini-batch梯度下降训练的时候,每批训练数据的分布不相同,那么网络就
要在每次迭代的时候去学习以适应不同的分布,这样将会大大降低网络的训练速度。

而 Normalization 能够很好的使得样本处于同一个分布

2.1 传统 Normalization

上面所说的传统 Normalization 在数学里面都学过,也叫归一化,常见的形式如下:

在非线性变换之后或者线性变换之前对 x x x 进行标准化处理(减去均值,除标准差),让数据处于均值为0、方差为1的分布中,以降低样本间的差异性。而 Batch Normalization 则是指对一个 batch 进行 Normalization .

这样使得对应的样本计算出的梯度处于中间的中心区域(图中红色显示区域)。因为梯度一直都能保持比较大的状态,所以很明显对神经网络的参数调整效率比较高,就是变动大,就是说向损失函数最优值迈动的步子大,反向传播信息流动性更强,加快训练收敛速度。

如果仅仅使用这样的方法,对网络某一层 l l l 的输出数据做归一化,然后送到网络下一层B中,这样会影响本层网络 l l l 所学到的特征,从而导致数据表达能力的缺失。

另一方面,通过让每一层的输入分布均值为0,方差为1,会使得输入在经过sigmoid或tanh激活函数时,容易陷入非线性激活函数的线性区域,即0附近的线性区域。

这样一来非线性激活函数就起不到相应的非线性变换的作用,或者就是相当于一个线性层罢了,那么网络的非线性表达能力就下降了。

2.2 改进

因此,BN又引入了两个可学习(learnable)的参数 γ \\gamma γ β \\beta β,这两个参数的引入是为了恢复数据本身的表达能力,对规范化后的数据进行线性变换,即:

                                                                                                          Z j ~ = γ j Z j ^ + β j \\tilde{Z_{j}} = \\gamma_j\\hat{Z_{j}} + \\beta_j Zj~=γjZj^+βj

这个操作使数据在中心区域附近的线性区域往旁边的非线性区域进行了一定的偏移,即通过 γ j \\gamma_j γj β j \\beta_j βj 把原来的输出值从标准正态分布左移或者右移一点,使得曲线更加胖一点或瘦一点,每个实例挪动的程度不一样,这样等价于非线性函数的值从正中心周围的线性区往非线性区进行了扩散移动。

核心思想: 找到一个线性和非线性的较好平衡点,既能享受非线性的较强表达能力的好处( γ j \\gamma_j γj β j \\beta_j βj 带来的),又避免太靠非线性区两头使得网络收敛速度太慢(Normalization 带来的)。这两个参数的核心思想就是兼顾线性的快速收敛,与非线性的较强表达能力。这两个参数需要通过学习得到的。

2.3 原文算法4


三、Batch Normalization 在测试阶段的使用

3.1 测试阶段如何计算 μ l \\mu_l μl σ l 2 \\sigma_l^2 σl2 5

在训练阶段,各层的 μ l \\mu_l μl σ l 2 \\sigma_l^2 σl2 是通过当前层得到的输入 batch 计算而得。

而测试阶段有可能仅输入一个或者极少样本,它对应的 μ l \\mu_l μl σ l 2 \\sigma_l^2 σl2是没有意义的,这时候该如何计算 μ l \\mu_l μl σ l 2 \\sigma_l^2 σl2 呢?

针对每一层 l l l 而言:因为在训练结束后,每一层的参数已经固定好了,那么每一层有很多个已经计算过的mini-batch,则有这些 batch 对应的 μ l \\mu_l μl σ l 2 \\sigma_l^2 σl2,在训练时,把这些值都保存下来;在测试时,通过计算 μ l \\mu_l μl的数学期望,以及 σ l 2 \\sigma_l^2 σl2 的无偏估计,从而间接求出该层的全局统计量: