GAN理论推导(知乎转载)
Posted just_sort
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GAN理论推导(知乎转载)相关的知识,希望对你有一定的参考价值。
GAN理论推导
在知乎上看到一个对GAN推导得十分仔细的文章,写得非常好,我准备按照他的思路推导一下GAN的理论。可以理解为这篇文章转载自:https://zhuanlan.zhihu.com/p/27295635
GAN的原理
首先我们知道真实图片集的分布
P
d
a
t
a
(
x
)
P_data(x)
Pdata(x),x是一个真实的图片,可以想象为一个向量,这个向量集合的分布就是
P
d
a
t
a
P_data
Pdata。我们现在有Generator生成的分布假设为
p
G
(
x
;
θ
)
p_G(x;\\theta)
pG(x;θ),这是一个由
θ
\\theta
θ控制的分布,
θ
\\theta
θ是这个分布的参数(如果是高斯混合模型,那么
θ
\\theta
θ就是每个高斯分布的平均值和方差),假设我们再真实分布中取一些数据,
x
1
,
x
2
,
.
.
.
,
X
m
x^1,x^2,...,X^m
x1,x2,...,Xm,我们想要计算一个似然
P
G
(
x
i
;
θ
)
P_G(x^i;\\theta)
PG(xi;θ),关于似然的理解可以参考这篇博客:https://blog.csdn.net/weixin_40499753/article/details/82977623 对于这些数据,在生成模型中的似然就是
L
=
∏
i
=
1
m
P
G
(
x
i
;
θ
)
L=\\prod_i=1^mP_G(x^i;\\theta)
L=∏i=1mPG(xi;θ), 我们想要最大化这个似然,等价于让generator生成那些真实图片的概率最大,这就变成了一个最大似然估计的问题了,我们需要找到一个参数
θ
∗
\\theta^*
θ∗来最大化这个似然。公式推导如下:
我们寻找一个
θ
∗
\\theta^*
θ∗来最大化这个似然,等价于最大化log似然。因为此时这m个数据是从真实分布中取得,所以也就约等于真实分布中的所有x在
P
G
P_G
PG分布中的log似然的期望。真实分布中的所有x的期望,等价于求概率积分,可以转化为积分运算,因为减号后面的项和
θ
\\theta
θ无关,所以添加上之后还是等价的。然后提出共有的项,括号内的反转,max变为min,就可以转化为KL散度的形式了,KL散度描述的是2个向量之间的差异。所以最大化似然,让generator最大概率的生成真实图片,也就是要找一个
θ
\\theta
θ让
P
G
P_G
PG更接近于
P
d
a
t
a
P_data
Pdata,那如何来找这个最合理的
θ
\\theta
θ呢?我们可以假设
P
G
(
x
;
θ
)
P_G(x;\\theta)
PG(x;θ)是一个神经网络。首先随机一个向量z,通过G(z)=x这个网络生成图片x,那么如何比较两个分布是否相似呢?只要我们取一组sample z,这组z符合一个分布,那么通过网络就可以生成另外一个分布
P
G
P_G
PG,然后来和真实分布
P
d
a
t
a
P_data
Pdata比较。
如何来找更接近的分布,这就是GAN的核心贡献了。GAN的公式为:这个式子的好处在于,固定G,max V(G, D)就表示
P
G
P_G
PG和
P
d
a
t
a
P_data
Pdata之间的差异,然后要找一个最好的G,让这个最大值最小,也就是2个分布之间的差异最小。表面上看这个的意思是,D要让这个式子尽可能的大,也就是对于x是真实分布中,D(x)要接近与1,对于x来自于生成的分布,D(x)要接近于0,然后G要让式子尽可能的小,让来自于生成分布中的x,D(x)尽可能的接近1。
现在我们先固定G,来求解最优的D:
对于一个给定的x,得到最优的D如上图,范围在(0,1)内,把最优的D带入可以得到:
JS divergence是KL divergence的对称平滑版本,表示了两个分布之间的差异,这个推导就表明了上面所说的,固定G,表示两个分布之间的差异,最小值是-2log2,最大值为0。现在我们需要找个G,来最小化观察上式,当时,G是最优的。
训练
有了上面推导的基础之后,我们就可以开始训练GAN了。结合我们开头说的,两个网络交替训练,我们可以在起初有一个
G
0
G_0
G0和
D
0
D_0
D0,先训练
D
0
D_0
D0找到,然后固定
D
0
D_0
D0开始训练
G
0
G_0
G0,训练的过程都可以使用gradient descent,以此类推,训练
D
1
,
G
1
,
D
2
,
G
2
.
.
.
D_1,G_1,D_2,G_2...
D1,G1,D2,G2...
避免上述情况的方法就是更新G的时候,不要更新G太多。
知道了网络的训练顺序,我们还需要设定两个loss function,一个是D的loss,一个是G的loss。下面是整个GAN的训练具体步骤:
上述步骤在机器学习和深度学习中也是非常常见,易于理解。
存在的问题
但是上面G的loss function还是有一点小问题,下图是两个函数的图像:
l
o
g
(
1
−
D
(
x
)
)
log(1-D(x))
log(1−D(x))是我们计算时G的loss function,但是我们发现,在D(x)接近于0的时候,这个函数十分平滑,梯度非常的小。这就会导致,在训练的初期,G想要骗过D,变化十分的缓慢,而上面的函数,趋势和下面的是一样的,都是递减的。但是它的优势是在D(x)接近0的时候,梯度很大,有利于训练,在D(x)越来越大之后,梯度减小,这也很符合实际,在初期应该训练速度更快,到后期速度减慢。
还有可能的原因是,虽然两个分布都是高维的,但是两个分布都十分的窄,可能交集相当小,这样也会导致JS divergence算出来=log2,约等于没有交集。解决的一些方法,有添加噪声,让两个分布变得更宽,可能可以增大它们的交集,这样JS divergence就可以计算,但是随着时间变化,噪声需要逐渐变小。
还有一个问题叫Mode Collapse,如下图:
这个图的意思是,data的分布是一个双峰的,但是学习到的生成分布却只有单峰,我们可以看到模型学到的数据,但是却不知道它没有学到的分布。
造成这个情况的原因是,KL divergence里的两个分布写反了,
以上是关于GAN理论推导(知乎转载)的主要内容,如果未能解决你的问题,请参考以下文章
LOOPS HDU - 3853 (概率dp):(希望通过该文章梳理自己的式子推导)