WGAN的来龙去脉
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了WGAN的来龙去脉相关的知识,希望对你有一定的参考价值。
参考技术A 渐渐体会到GAN训练难确实是一个让人头疼的问题,一个多月前我曾粗略地了解了一下WGAN,知道这是一个着眼于提高GAN训练稳定性的成果,但后来发现,我对其原理理解得还不是很充足。于是我把WGAN的一作作者Martin Arjovsky在2017年先后参与的三篇相关论文找来看,对WGAN的来龙去脉有了一个更清晰的理解。这篇论文是WGAN发表前的铺垫,它最大的贡献是从理论上解释了GAN训练不稳定的原因。
人们在应用GAN时经常发现一个现象:不能把Discriminator训练得太好,否则Generator的性能很难提升上去。该文以此为出发点,分析了GAN目标函数的理论缺陷。
在最早提出GAN的论文中,Goodfellow把GAN的目标函数设置为:
他也证明了,固定Generator时,最优的Discriminator是
然后在面对最优Discriminator时,Generator的优化目标就变成了
可以把上述公式简洁地写成JS散度的形式:
也就是说,如果把Discriminator训练到极致,那么整个GAN的训练目标就成了最小化真实数据分布与合成数据分布之间的JS散度。
该文花了大量的篇幅进行数学推导,证明在一般的情况下,上述有关JS散度的目标函数会带来 梯度消失 的问题。也就是说,如果Discriminator训练得太好,Generator就无法得到足够的梯度继续优化,而如果Discriminator训练得太弱,指示作用不显著,同样不能让Generator进行有效的学习。这样一来,Discriminator的训练火候就非常难把控,这就是GAN训练难的根源。
该文还用实验对这一结论进行了验证:让Generator固定,然后从头开始训练Discriminator,绘制出Generator目标函数梯度和训练迭代次数的关系如下。
可以看到,经过25 epochs的训练以后,Generator得到的梯度已经非常小了,出现了明显的梯度消失问题。
Goodfellow提到过可以把Generator的目标函数改为-logD的形式,在实际应用中,人们也发现这个形式更好用,该文把这个技巧称为 the - log D alternative 。此时Generator的梯度是:
该文证明在最优的Discriminator下,这个梯度可以转化为KL散度和JS散度的组合:
该文对这一结论有两点评论:
1. 该公式的第二项意味着最大化真实数据分布和生成数据分布之间的JS散度,也就是让两者差异化更大,这显然违背了最初的优化目标,算是一种缺陷。
2. 同时,第一项的KL散度会被最小化,这会带来严重的 mode dropping 问题。
关于上述第二点,下面补充一点说明。
mode dropping在更多的情况下被称作mode collapse,指的是生成样本只集中于部分的mode从而缺乏多样性的情况。例如,MNIST数据分布一共有10个mode(0到9共10个数字),如果Generator生成的样本几乎只有其中某个数字,那么就是出现了很严重的mode collapse现象。
接下来解释为什么上述的KL散度
会导致mode collapse。借用网上 某博客 的图,真实的数据分布记为P,生成的数据分布记为Q,图的左边表示两个分布的轮廓,右边表示两种KL散度的分布(由于KL散度的不对称性,KL(P||Q)与KL(Q||P)是不同的)。
右图蓝色的曲线代表KL(Q||P),相当于上述的
可以看到,KL(Q||P)会更多地惩罚q(x) > 0而p(x) -> 0的情况(如x = 2附近),也就是惩罚“生成样本质量不佳”的错误;另一方面,当p(x) > 0而q(x) -> 0时(如x = -3附近),KL(Q||P)给出的惩罚几乎是0,表示对“Q未能广泛覆盖P涉及的区域”不在乎。如此一来,为了“安全”起见,最终的Q将谨慎地覆盖P的一小部分区域,即Generator会生成大量高质量却缺乏多样性的样本,这就是mode collapse问题。
另外,通过类似的分析可以知道,KL(P||Q)则会导致Generator生成多样性强却低质量的样本。
除了上述的缺陷,该文还通过数学证明这种-logD的目标函数还存在梯度方差较大的缺陷,导致训练的不稳定。然后同样通过实验直观地验证了这个现象,如下图,在训练的早期(训练了1 epoch和训练了10 epochs),梯度的方差很大,因此对应的曲线看起来比较粗,直到训练了25 epochs以后GAN收敛了才出现方差较小的梯度。
该文通过严谨的理论推导分析了当前GAN训练难的根源:原始的目标函数容易导致梯度消失;改进后的-logD trick虽然解决了梯度消失的问题,然而又带来了mode collapse、梯度不稳定等问题,同样存在理论缺陷。既然深入剖析了问题的根源,该文自然在最后也提出了一个解决方案,然而该方案毕竟不如后来的WGAN那样精巧,因此我把这部分略过了。
这是最早提出WGAN的论文,沿着上篇论文的思路,该文认为需要对“生成分布与真实分布之间的距离”探索一种更合适的度量方法。作者们把眼光转向了 Earth-Mover 距离,简称 EM 距离,又称 Wasserstein 距离。
EM距离的定义为:
解释如下: 是 和 组合起来的所有可能的联合分布的集合,对于每一个可能的联合分布 而言,可以从中采样 得到一个真实样本 和一个生成样本 ,并算出这对样本的距离 ,所以可以计算该联合分布下样本对距离的期望值 。在所有可能的联合分布中能够对这个期望值取到的下界,就定义为EM距离。
Earth-Mover的本意是推土机的意思,这个命名很贴切,因为从直观上理解,EM距离就是在衡量把 Pr 这堆“沙土”“推”到 Pg 这个“位置”所要花费的最小代价,其中的γ就是一种“推土”方案。
该文接下来又通过数学证明,相比JS、KL等距离,EM距离的变化更加敏感,能提供更有意义的梯度,理论上显得更加优越。
作者们自然想到把EM距离用到GAN中。直接求解EM距离是很难做到的,不过可以用一个叫 Kantorovich-Rubinstein duality 的理论把问题转化为:
这个公式的意思是对所有满足 1-Lipschitz 限制的函数 取到 的上界。简单地说,Lipschitz限制规定了一个连续函数的最大局部变动幅度,如K-Lipschitz就是: 。
然后可以用神经网络的方法来解决上述优化问题:
这个神经网络和GAN中的Discriminator非常相似,只存在一些细微的差异,作者把它命名为Critic以便与Discriminator作区分。两者的不同之处在于:
1. Critic最后一层抛弃了sigmoid,因为它输出的是一般意义上的分数,而不像Discriminator输出的是概率。
2. Critic的目标函数没有log项,这是从上面的推导得到的。
3. Critic在每次更新后都要把参数截断在某个范围,即 weight clipping ,这是为了保证上面讲到的 Lipschitz 限制。
4. Critic训练得越好,对Generator的提升更有利,因此可以放心地多训练Critic。
这样的简单修改就是WGAN的核心了,虽然数学证明很复杂,最后的变动却十分简洁。总结出来的WGAN算法为:
GAN与WGAN的对比如下图:
最后,该文用一系列的实验说明了WGAN的几大优越之处:
1. 不再需要纠结如何平衡Generator和Discriminator的训练程度,大大提高了GAN训练的稳定性:Critic(Discriminator)训练得越好,对提升Generator就越有利。
2. 即使网络结构设计得比较简陋,WGAN也能展现出良好的性能,包括避免了mode collapse的现象,体现了出色的鲁棒性。
3. Critic的loss很准确地反映了Generator生成样本的质量,因此可以作为展现GAN训练进度的定性指标。
紧接着上面的工作,这篇论文对刚提出的WGAN做了一点小改进。
作者们发现WGAN有时候也会伴随样本质量低、难以收敛等问题。WGAN为了保证Lipschitz限制,采用了weight clipping的方法,然而这样的方式可能过于简单粗暴了,因此他们认为这是上述问题的罪魁祸首。
具体而言,他们通过简单的实验,发现weight clipping会导致两大问题:模型建模能力弱化,以及梯度爆炸或消失。
他们提出的替代方案是给Critic loss加入 gradient penalty (GP) ,这样,新的网络模型就叫 WGAN-GP 。
GP项的设计逻辑是:当且仅当一个可微函数的梯度范数(gradient norm)在任意处都不超过1时,该函数满足1-Lipschitz条件。至于为什么限制Critic的梯度范数趋向1(two-sided penalty)而不是小于1(one-sided penalty),作者给出的解释是,从理论上最优Critic的梯度范数应当处处接近1,对Lipschitz条件的影响不大,同时从实验中发现two-sided penalty效果比one-sided penalty略好。
另一个值得注意的地方是,用于计算GP的样本 是生成样本和真实样本的线性插值,直接看算法流程更容易理解:
最后,该论文也通过实验说明,WGAN-GP在训练的速度和生成样本的质量上,都略胜WGAN一筹。
WGAN的介绍
参考技术A GAN模型由生成式模型(generative model)和判别式模型(discriminative model)充当,这里以生成图片为例进行说明。它们的功能分别是:- 生成模型G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。
·判别模型D是用来判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”。
·不收敛(non-convergence)的问题
·难以训练
梯度消失(gradient vanishing)
崩溃问题(mode collapse)
·模型过于自由不可控
参考:
·Goodfellow Ian, Pouget-AbadieJ, Mirza M, et al. Generative adversarialnets[C]//Advances in NeuralInformation Processing Systems. 2014: 2672-2680.
·生成式对抗网络GAN研究进展(二)——原始GAN http://blog.csdn.net/solomon1558/article/details/52549409
主要发展
CGAN(条件生成对抗网络)
针对问题:模型过于自由不可控。
方法:输入更多信息到GAN模型学习,生成更好的样本。
效果:提高生成图像的质量,明确控制图像的某些方面。
参考:【1】Mirza M, Osindero S.Conditional Generative Adversarial Nets[J]. Computer Science, 2014:2672-2680.
【2】生成式对抗网络GAN研究进展(三)——条件GAN
http://blog.csdn.net/solomon1558/article/details/52549409
DCGAN(深度卷积生成对抗网络)
结合了有监督学习的CNN和无监督的GAN
针对问题:GAN训练不稳定,经常生成无意义的输出。
方法:生成模型和判别模型均采用CNN模型,并在结构上做了一些改变。
参考:【1】Radford A, Metz L, Chintala S. Unsupervised representation learningwith deep convolutional generative adversarial networks[J]. arXiv preprintarXiv:1511.06434, 2015.
【2】生成式对抗网络GAN研究进展(五)——Deep Convolutional Generative
Adversarial Nerworks,DCGAN
http://blog.csdn.net/solomon1558/article/details/52573596
WGAN的贡献
原始GAN训练困难的分析
1)训练目标:
x表示真实图片,z表示输入G网络的噪声,而G(z)表示G网络生成的图片。
判别器LOSS(最小化):
生成器LOSS(最小化):
2)训练过程:
· 先固定生成器,训练判别器达到最优,然后训练生成器。
· 利用SGD训练判别器达到最优解为:
训练生成器(在判别器最优时):
A、
最终变换形式:即最小化Pr和Pg之间的JS散度。
结论: 由于P_r与P_g几乎不可能有不可忽略的重叠,所以无论它们相距多远JS散度都是常数log 2,最终导致生成器的梯度(近似)为0,梯度消失。
B、
最终变换形式:
最小化目标分析:
u最小化生成分布与真实分布的KL散度,却又要最大化两者的JS散度,在数值上则会导致梯度不稳定。
uKL散度会造成两种错误:生成器没能生成真实样本(缺乏多样性)
生成器生成不真实样本(缺乏准确性)
小结: 在原始GAN的(近似)最优判别器下,第一种生成器loss面临梯度消失问题,第二种生成器loss面临优化目标荒谬、梯度不稳定、对多样性与准确性惩罚不平衡导致mode collapse这几个问题。
WGAN的内容
Wasserstein距离:
(EM距离)
相比KL散度、JS散度的优越性:即便两个分布没有重叠,Wasserstein距离仍能反映它们的远近。KL散度和JS散度是突变的,Wasserstein距离却是平滑的,可以提供有意义的梯度。
WGAN形式
对偶问题:
要求函数f的导函数绝对值不超过K的条件下,对所有可能满足条件的f取到上式的上界,然后再除以K。
进而,用该距离做GAN的LOSS函数,可得:
生成器loss函数:
判别器loss函数:
可以表示训练进程中,其数值越小,表示真实分布与生成分布的Wasserstein距离越小,GAN训练得越好。
·判别器所近似的Wasserstein距离与生成器的生成图片质量高度相关
总结:EM距离相对KL散度与JS散度具有优越的平滑特性,理论上可以解决梯度消失问题。在此近似最优判别器下优化生成器使得Wasserstein距离缩小,就能有效拉近生成分布与真实分布。WGAN既解决了训练不稳定的问题,也提供了一个可靠的训练进程指标,而且该指标确实与生成样本的质量高度相关。
参考:
【1】M. Arjovsky, S.Chintala, and L. Bottou. Wasserstein gan. ArXiv,2017.
【2】令人拍案叫绝的Wasserstein GAN https://zhuanlan.zhihu.com/p/25071913
以上是关于WGAN的来龙去脉的主要内容,如果未能解决你的问题,请参考以下文章
GAN1-对抗神经网络梳理(GAN,WGAN,WGAN-GP)