ADMM算法在神经网络模型剪枝方面的应用
Posted 夏小悠
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了ADMM算法在神经网络模型剪枝方面的应用相关的知识,希望对你有一定的参考价值。
前言
本篇博客记录一下自己根据对论文 GRIM: A General, Real-Time Deep Learning Inference Framework for Mobile Devices based on Fine-Grained Structured Weight Sparsity 中提到的ADMM算法的理解,给出了ADMM算法的推导过程,并在文章的末尾提供了实现的代码。
1. 交替方向乘子法
交替方向乘子法(Alternating Direction Method of Multipliers, ADMM)作为一种求解优化问题的计算框架,适用于求解凸优化问题。ADMM算法的思想根源可以追溯到20世纪50年代,在20世纪八九十年代中期存在大量的文章分析这种方法的性质,但是当时ADMM主要用于解决偏微分方程问题。1970年由 R. Glowinski 和 D. Gabay 等提出的一种适用于可分离凸优化的简单有效方法,并在统计机器学习、数据挖掘和计算机视觉等领域中得到了广泛应用。ADMM算法主要解决带有等式约束的关于两个变量的目标函数的最小化问题,可以看作在增广拉朗格朗日算法基础上发展的算法,混合了对偶上升算法(Dual Ascent)的可分解性和乘子法(Method of Multipliers)的算法优越的收敛性。相对于乘子法,ADMM算法最大的优势在于其能够充分利用目标函数的可分解性,对目标函数中的多变量进行交替优化。在解决大规模问题上,利用ADMM算法可以将原问题的目标函数等价地分解成若干个可求解的子问题,然后并行求解每一个子问题,最后协调子问题的解得到原问题的全局解。1
优化问题
m
i
n
i
m
i
z
e
f
(
x
)
+
g
(
z
)
s
u
b
j
e
c
t
t
o
A
x
+
B
z
=
c
minimize\\ f(x)+g(z) \\\\ subject\\ to\\ Ax+Bz=c
minimize f(x)+g(z)subject to Ax+Bz=c 其中,
x
∈
R
n
,
z
∈
R
m
,
A
∈
R
p
×
n
,
B
∈
R
p
×
m
,
c
∈
R
p
x \\in R^n,z \\in R^m,A \\in R^{p \\times n},B \\in R^{p \\times m},c \\in R^p
x∈Rn,z∈Rm,A∈Rp×n,B∈Rp×m,c∈Rp,构造拉格朗日函数为
L
p
(
x
,
z
,
λ
)
=
f
(
x
)
+
g
(
z
)
+
λ
T
(
A
x
+
B
z
−
c
)
L_p(x,z,\\lambda )=f(x)+g(z)+\\lambda ^{T}(Ax+Bz-c)
Lp(x,z,λ)=f(x)+g(z)+λT(Ax+Bz−c) 其增广拉格朗日函数(augmented Lagrangian function)为
L
p
(
x
,
z
,
λ
)
=
f
(
x
)
+
g
(
z
)
+
λ
T
(
A
x
+
B
z
−
c
)
+
ρ
2
∣
∣
A
x
+
B
z
−
c
∣
∣
2
L_p(x,z,\\lambda )=f(x)+g(z)+\\lambda ^{T}(Ax+Bz-c)+ \\frac {\\rho} {2}||Ax+Bz-c||^{2}
Lp(x,z,λ)=f(x)+g(z)+λT(Ax+Bz−c)+2ρ∣∣Ax+Bz−c∣∣2 对偶上升法迭代更新
(
x
k
+
1
,
z
k
+
1
)
=
a
r
g
m
i
n
x
,
z
L
p
(
x
,
z
,
λ
k
)
λ
k
+
1
=
λ
k
+
ρ
(
A
x
k
+
1
+
B
z
k
+
1
−
c
)
(x^{k+1},z^{k+1})=\\underset {x,z} {argmin\\ } L_p(x,z,\\lambda ^k) \\\\ \\lambda ^{k+1}=\\lambda ^k+\\rho (Ax^{k+1}+Bz^{k+1}-c)
(xk+1,zk+1)=x,zargmin Lp(x,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1−c) 交替方向乘子法则是在
(
x
,
z
)
(x,z)
(x,z)一起迭代的基础上将
x
,
z
x,z
x,z分别固定单独交替迭代,即
x
k
+
1
=
a
r
g
m
i
n
x
L
p
(
x
,
z
k
,
λ
k
)
z
k
+
1
=
a
r
g
m
i
n
z
L
p
(
x
k
+
1
,
z
,
λ
k
)
λ
k
+
1
=
λ
k
+
ρ
(
A
x
k
+
1
+
B
z
k
+
1
−
c
)
x^{k+1}=\\underset {x} {argmin\\ }L_p(x,z^k,\\lambda ^k) \\\\ z^{k+1}=\\underset {z} {argmin\\ }L_p(x^{k+1},z,\\lambda ^k) \\\\ \\lambda ^{k+1}=\\lambda ^k+\\rho (Ax^{k+1}+Bz^{k+1}-c)
xk+1=xargmin Lp(x,zk,λk)zk+1=zargmin Lp(xk+1,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1−c) 交替方向乘子的另一种等价形式,将残差定义为
r
k
=
A
x
k
+
B
z
k
−
c
r^k=Ax^k+Bz^k-c
rk=Axk+Bzk−c,同时定义
u
k
=
1
ρ
λ
k
u^k=\\frac {1} {\\rho} \\lambda ^k
uk=ρ1λk作为缩放的对偶变量(dual variable),有
(
λ
k
)
T
r
k
+
ρ
2
∣
∣
r
k
∣
∣
2
=
ρ
2
∣
∣
r
k
+
u
k
∣
∣
2
−
ρ
2
∣
∣
u
k
∣
∣
2
(\\lambda ^k)^Tr^k+\\frac {\\rho} {2} ||r^k||^2=\\frac {\\rho} {2}||r^k+u^k||^2-\\frac {\\rho} {2}||u^k||^2
(λk)Trk+2ρ∣∣rk∣∣2=2ρ∣∣rk+u第55篇剪枝算法:通过网络瘦身学习高效卷积网络