神经网络权重衰减(weight-decay)
Posted ZSYL
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了神经网络权重衰减(weight-decay)相关的知识,希望对你有一定的参考价值。
权重衰减
上一节中我们观察了过拟合现象,即模型的训练误差远小于它在测试集上的误差。虽然增大训练数据集可能会减轻过拟合,但是获取额外的训练数据往往代价高昂。本节介绍应对过拟合问题的常用方法:权重衰减(weight decay)。
方法
权重衰减等价于 L 2 L_2 L2 范数正则化(regularization)。正则化通过为模型损失函数添加惩罚项使学出的模型参数值较小,是应对过拟合的常用手段。我们先描述 L 2 L_2 L2范数正则化,再解释它为何又称权重衰减。
L 2 L_2 L2范数正则化在模型原损失函数基础上添加 L 2 L_2 L2范数惩罚项,从而得到训练所需要最小化的函数。 L 2 L_2 L2范数惩罚项指的是模型权重参数每个元素的平方和与一个正的常数的乘积。以3.1节(线性回归)中的线性回归损失函数
ℓ ( w 1 , w 2 , b ) = 1 n ∑ i = 1 n 1 2 ( x 1 ( i ) w 1 + x 2 ( i ) w 2 + b − y ( i ) ) 2 \\ell(w_1, w_2, b) = \\frac1n \\sum_i=1^n \\frac12\\left(x_1^(i) w_1 + x_2^(i) w_2 + b - y^(i)\\right)^2 ℓ(w1,w2,b)=n1i=1∑n21(x1(i)w1+x2(i)w2+b−y(i))2
为例,其中 w 1 , w 2 w_1, w_2 w1,w2是权重参数, b b b是偏差参数,样本 i i i的输入为 x 1 ( i ) , x 2 ( i ) x_1^(i), x_2^(i) x1(i),x2(i),标签为 y ( i ) y^(i) y(i),样本数为 n n n。将权重参数用向量 w = [ w 1 , w 2 ] \\boldsymbolw = [w_1, w_2] w=[w1,w2]表示,带有 L 2 L_2 L2范数惩罚项的新损失函数为
ℓ ( w 1 , w 2 , b ) + λ 2 n ∥ w ∥ 2 , \\ell(w_1, w_2, b) + \\frac\\lambda2n \\|\\boldsymbolw\\|^2, ℓ(w1,w2,b)+2nλ∥w∥2,
其中超参数 λ > 0 \\lambda > 0 λ>0。当权重参数均为0时,惩罚项最小。当 λ \\lambda λ较大时,惩罚项在损失函数中的比重较大,这通常会使学到的权重参数的元素较接近0。当 λ \\lambda λ设为0时,惩罚项完全不起作用。
上式中 L 2 L_2 L2范数平方 ∥ w ∥ 2 \\|\\boldsymbolw\\|^2 ∥w∥2展开后得到 w 1 2 + w 2 2 w_1^2 + w_2^2 w12+w22。有了 L 2 L_2 L2范数惩罚项后,在小批量随机梯度下降中,我们将线性回归一节中权重 w 1 w_1 w1和 w 2 w_2 w2的迭代方式更改为:
w 1 ← ( 1 − η λ ∣ B ∣ ) w 1 − η ∣ B ∣ ∑ i ∈ B x 1 ( i ) ( x 1 ( i ) w 1 + x 2 ( i ) w 2 + b − y ( i ) ) , w 2 ← ( 1 − η λ ∣ B ∣ ) w 2 − η ∣ B ∣ ∑ i ∈ B x 2 ( i ) ( x 1 ( i ) w 1 + x 2 ( i ) w 2 + b − y ( i ) ) . \\beginaligned w_1 &\\leftarrow \\left(1- \\frac\\eta\\lambda|\\mathcalB| \\right)w_1 - \\frac\\eta|\\mathcalB| \\sum_i \\in \\mathcalBx_1^(i) \\left(x_1^(i) w_1 + x_2^(i) w_2 + b - y^(i)\\right),\\\\ w_2 &\\leftarrow \\left(1- \\frac\\eta\\lambda|\\mathcalB| \\right)w_2 - \\frac\\eta|\\mathcalB| \\sum_i \\in \\mathcalBx_2^(i) \\left(x_1^(i) w_1 + x_2^(i) w_2 + b - y^(i)\\right). \\endaligned w1w2←(1−∣B∣ηλ)w1−∣B∣ηi∈B∑x1(i)(x1(i)w1+x2(i)w2+b−y(i)),←(1−∣B∣ηλ)w2−∣B∣ηweight decay