GAN理论推导(知乎转载)

Posted just_sort

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GAN理论推导(知乎转载)相关的知识,希望对你有一定的参考价值。

GAN理论推导

在知乎上看到一个对GAN推导得十分仔细的文章,写得非常好,我准备按照他的思路推导一下GAN的理论。可以理解为这篇文章转载自:https://zhuanlan.zhihu.com/p/27295635

GAN的原理

首先我们知道真实图片集的分布 P d a t a ( x ) P_data(x) Pdata(x),x是一个真实的图片,可以想象为一个向量,这个向量集合的分布就是 P d a t a P_data Pdata。我们现在有Generator生成的分布假设为 p G ( x ; θ ) p_G(x;\\theta) pG(x;θ),这是一个由 θ \\theta θ控制的分布, θ \\theta θ是这个分布的参数(如果是高斯混合模型,那么 θ \\theta θ就是每个高斯分布的平均值和方差),假设我们再真实分布中取一些数据, x 1 , x 2 , . . . , X m x^1,x^2,...,X^m x1,x2,...,Xm,我们想要计算一个似然 P G ( x i ; θ ) P_G(x^i;\\theta) PG(xi;θ),关于似然的理解可以参考这篇博客:https://blog.csdn.net/weixin_40499753/article/details/82977623 对于这些数据,在生成模型中的似然就是 L = ∏ i = 1 m P G ( x i ; θ ) L=\\prod_i=1^mP_G(x^i;\\theta) L=i=1mPG(xi;θ), 我们想要最大化这个似然,等价于让generator生成那些真实图片的概率最大,这就变成了一个最大似然估计的问题了,我们需要找到一个参数 θ ∗ \\theta^* θ来最大化这个似然。公式推导如下:
我们寻找一个 θ ∗ \\theta^* θ来最大化这个似然,等价于最大化log似然。因为此时这m个数据是从真实分布中取得,所以也就约等于真实分布中的所有x在 P G P_G PG分布中的log似然的期望。真实分布中的所有x的期望,等价于求概率积分,可以转化为积分运算,因为减号后面的项和 θ \\theta θ无关,所以添加上之后还是等价的。然后提出共有的项,括号内的反转,max变为min,就可以转化为KL散度的形式了,KL散度描述的是2个向量之间的差异。所以最大化似然,让generator最大概率的生成真实图片,也就是要找一个 θ \\theta θ P G P_G PG更接近于 P d a t a P_data Pdata,那如何来找这个最合理的 θ \\theta θ呢?我们可以假设 P G ( x ; θ ) P_G(x;\\theta) PG(x;θ)是一个神经网络。首先随机一个向量z,通过G(z)=x这个网络生成图片x,那么如何比较两个分布是否相似呢?只要我们取一组sample z,这组z符合一个分布,那么通过网络就可以生成另外一个分布 P G P_G PG,然后来和真实分布 P d a t a P_data Pdata比较。
如何来找更接近的分布,这就是GAN的核心贡献了。GAN的公式为:这个式子的好处在于,固定G,max V(G, D)就表示 P G P_G PG P d a t a P_data Pdata之间的差异,然后要找一个最好的G,让这个最大值最小,也就是2个分布之间的差异最小。表面上看这个的意思是,D要让这个式子尽可能的大,也就是对于x是真实分布中,D(x)要接近与1,对于x来自于生成的分布,D(x)要接近于0,然后G要让式子尽可能的小,让来自于生成分布中的x,D(x)尽可能的接近1。
现在我们先固定G,来求解最优的D:
对于一个给定的x,得到最优的D如上图,范围在(0,1)内,把最优的D带入可以得到:
JS divergence是KL divergence的对称平滑版本,表示了两个分布之间的差异,这个推导就表明了上面所说的,固定G,表示两个分布之间的差异,最小值是-2log2,最大值为0。现在我们需要找个G,来最小化观察上式,当时,G是最优的。

训练

有了上面推导的基础之后,我们就可以开始训练GAN了。结合我们开头说的,两个网络交替训练,我们可以在起初有一个 G 0 G_0 G0 D 0 D_0 D0,先训练 D 0 D_0 D0找到,然后固定 D 0 D_0 D0开始训练 G 0 G_0 G0,训练的过程都可以使用gradient descent,以此类推,训练 D 1 , G 1 , D 2 , G 2 . . . D_1,G_1,D_2,G_2... D1,G1,D2,G2...
避免上述情况的方法就是更新G的时候,不要更新G太多。

知道了网络的训练顺序,我们还需要设定两个loss function,一个是D的loss,一个是G的loss。下面是整个GAN的训练具体步骤:
上述步骤在机器学习和深度学习中也是非常常见,易于理解。

存在的问题

但是上面G的loss function还是有一点小问题,下图是两个函数的图像:
l o g ( 1 − D ( x ) ) log(1-D(x)) log(1D(x))是我们计算时G的loss function,但是我们发现,在D(x)接近于0的时候,这个函数十分平滑,梯度非常的小。这就会导致,在训练的初期,G想要骗过D,变化十分的缓慢,而上面的函数,趋势和下面的是一样的,都是递减的。但是它的优势是在D(x)接近0的时候,梯度很大,有利于训练,在D(x)越来越大之后,梯度减小,这也很符合实际,在初期应该训练速度更快,到后期速度减慢。
还有可能的原因是,虽然两个分布都是高维的,但是两个分布都十分的窄,可能交集相当小,这样也会导致JS divergence算出来=log2,约等于没有交集。解决的一些方法,有添加噪声,让两个分布变得更宽,可能可以增大它们的交集,这样JS divergence就可以计算,但是随着时间变化,噪声需要逐渐变小。
还有一个问题叫Mode Collapse,如下图:
这个图的意思是,data的分布是一个双峰的,但是学习到的生成分布却只有单峰,我们可以看到模型学到的数据,但是却不知道它没有学到的分布。

造成这个情况的原因是,KL divergence里的两个分布写反了,

以上是关于GAN理论推导(知乎转载)的主要内容,如果未能解决你的问题,请参考以下文章

生成对抗网络GAN

LOOPS HDU - 3853 (概率dp):(希望通过该文章梳理自己的式子推导)

机器学习-白板推导系列(三十一)-生成对抗网络(GAN,Generative Adversarial Network)

转载:Logistic回归原理及公式推导

AUC 理论推导

教你编写第一个生成式对抗网络GAN