WGAN的来龙去脉

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了WGAN的来龙去脉相关的知识,希望对你有一定的参考价值。

参考技术A 渐渐体会到GAN训练难确实是一个让人头疼的问题,一个多月前我曾粗略地了解了一下WGAN,知道这是一个着眼于提高GAN训练稳定性的成果,但后来发现,我对其原理理解得还不是很充足。于是我把WGAN的一作作者Martin Arjovsky在2017年先后参与的三篇相关论文找来看,对WGAN的来龙去脉有了一个更清晰的理解。

这篇论文是WGAN发表前的铺垫,它最大的贡献是从理论上解释了GAN训练不稳定的原因。

人们在应用GAN时经常发现一个现象:不能把Discriminator训练得太好,否则Generator的性能很难提升上去。该文以此为出发点,分析了GAN目标函数的理论缺陷。

在最早提出GAN的论文中,Goodfellow把GAN的目标函数设置为:

他也证明了,固定Generator时,最优的Discriminator是

然后在面对最优Discriminator时,Generator的优化目标就变成了

可以把上述公式简洁地写成JS散度的形式:

也就是说,如果把Discriminator训练到极致,那么整个GAN的训练目标就成了最小化真实数据分布与合成数据分布之间的JS散度。

该文花了大量的篇幅进行数学推导,证明在一般的情况下,上述有关JS散度的目标函数会带来 梯度消失 的问题。也就是说,如果Discriminator训练得太好,Generator就无法得到足够的梯度继续优化,而如果Discriminator训练得太弱,指示作用不显著,同样不能让Generator进行有效的学习。这样一来,Discriminator的训练火候就非常难把控,这就是GAN训练难的根源。

该文还用实验对这一结论进行了验证:让Generator固定,然后从头开始训练Discriminator,绘制出Generator目标函数梯度和训练迭代次数的关系如下。

可以看到,经过25 epochs的训练以后,Generator得到的梯度已经非常小了,出现了明显的梯度消失问题。

Goodfellow提到过可以把Generator的目标函数改为-logD的形式,在实际应用中,人们也发现这个形式更好用,该文把这个技巧称为 the - log D alternative 。此时Generator的梯度是:

该文证明在最优的Discriminator下,这个梯度可以转化为KL散度和JS散度的组合:

该文对这一结论有两点评论:

1. 该公式的第二项意味着最大化真实数据分布和生成数据分布之间的JS散度,也就是让两者差异化更大,这显然违背了最初的优化目标,算是一种缺陷。

2. 同时,第一项的KL散度会被最小化,这会带来严重的 mode dropping 问题。

关于上述第二点,下面补充一点说明。

mode dropping在更多的情况下被称作mode collapse,指的是生成样本只集中于部分的mode从而缺乏多样性的情况。例如,MNIST数据分布一共有10个mode(0到9共10个数字),如果Generator生成的样本几乎只有其中某个数字,那么就是出现了很严重的mode collapse现象。

接下来解释为什么上述的KL散度

会导致mode collapse。借用网上 某博客 的图,真实的数据分布记为P,生成的数据分布记为Q,图的左边表示两个分布的轮廓,右边表示两种KL散度的分布(由于KL散度的不对称性,KL(P||Q)与KL(Q||P)是不同的)。

右图蓝色的曲线代表KL(Q||P),相当于上述的

可以看到,KL(Q||P)会更多地惩罚q(x) > 0而p(x) -> 0的情况(如x = 2附近),也就是惩罚“生成样本质量不佳”的错误;另一方面,当p(x) > 0而q(x) -> 0时(如x = -3附近),KL(Q||P)给出的惩罚几乎是0,表示对“Q未能广泛覆盖P涉及的区域”不在乎。如此一来,为了“安全”起见,最终的Q将谨慎地覆盖P的一小部分区域,即Generator会生成大量高质量却缺乏多样性的样本,这就是mode collapse问题。

另外,通过类似的分析可以知道,KL(P||Q)则会导致Generator生成多样性强却低质量的样本。

除了上述的缺陷,该文还通过数学证明这种-logD的目标函数还存在梯度方差较大的缺陷,导致训练的不稳定。然后同样通过实验直观地验证了这个现象,如下图,在训练的早期(训练了1 epoch和训练了10 epochs),梯度的方差很大,因此对应的曲线看起来比较粗,直到训练了25 epochs以后GAN收敛了才出现方差较小的梯度。

该文通过严谨的理论推导分析了当前GAN训练难的根源:原始的目标函数容易导致梯度消失;改进后的-logD trick虽然解决了梯度消失的问题,然而又带来了mode collapse、梯度不稳定等问题,同样存在理论缺陷。既然深入剖析了问题的根源,该文自然在最后也提出了一个解决方案,然而该方案毕竟不如后来的WGAN那样精巧,因此我把这部分略过了。

这是最早提出WGAN的论文,沿着上篇论文的思路,该文认为需要对“生成分布与真实分布之间的距离”探索一种更合适的度量方法。作者们把眼光转向了 Earth-Mover 距离,简称 EM 距离,又称 Wasserstein 距离。

EM距离的定义为:

解释如下: 是 和 组合起来的所有可能的联合分布的集合,对于每一个可能的联合分布 而言,可以从中采样 得到一个真实样本 和一个生成样本 ,并算出这对样本的距离 ,所以可以计算该联合分布下样本对距离的期望值 。在所有可能的联合分布中能够对这个期望值取到的下界,就定义为EM距离。

Earth-Mover的本意是推土机的意思,这个命名很贴切,因为从直观上理解,EM距离就是在衡量把 Pr 这堆“沙土”“推”到 Pg 这个“位置”所要花费的最小代价,其中的γ就是一种“推土”方案。

该文接下来又通过数学证明,相比JS、KL等距离,EM距离的变化更加敏感,能提供更有意义的梯度,理论上显得更加优越。

作者们自然想到把EM距离用到GAN中。直接求解EM距离是很难做到的,不过可以用一个叫 Kantorovich-Rubinstein duality 的理论把问题转化为:

这个公式的意思是对所有满足 1-Lipschitz 限制的函数 取到 的上界。简单地说,Lipschitz限制规定了一个连续函数的最大局部变动幅度,如K-Lipschitz就是: 。

然后可以用神经网络的方法来解决上述优化问题:

这个神经网络和GAN中的Discriminator非常相似,只存在一些细微的差异,作者把它命名为Critic以便与Discriminator作区分。两者的不同之处在于:

1. Critic最后一层抛弃了sigmoid,因为它输出的是一般意义上的分数,而不像Discriminator输出的是概率。

2. Critic的目标函数没有log项,这是从上面的推导得到的。

3. Critic在每次更新后都要把参数截断在某个范围,即 weight clipping ,这是为了保证上面讲到的 Lipschitz 限制。

4. Critic训练得越好,对Generator的提升更有利,因此可以放心地多训练Critic。

这样的简单修改就是WGAN的核心了,虽然数学证明很复杂,最后的变动却十分简洁。总结出来的WGAN算法为:

GAN与WGAN的对比如下图:

最后,该文用一系列的实验说明了WGAN的几大优越之处:

1. 不再需要纠结如何平衡Generator和Discriminator的训练程度,大大提高了GAN训练的稳定性:Critic(Discriminator)训练得越好,对提升Generator就越有利。

2. 即使网络结构设计得比较简陋,WGAN也能展现出良好的性能,包括避免了mode collapse的现象,体现了出色的鲁棒性。

3. Critic的loss很准确地反映了Generator生成样本的质量,因此可以作为展现GAN训练进度的定性指标。

紧接着上面的工作,这篇论文对刚提出的WGAN做了一点小改进。

作者们发现WGAN有时候也会伴随样本质量低、难以收敛等问题。WGAN为了保证Lipschitz限制,采用了weight clipping的方法,然而这样的方式可能过于简单粗暴了,因此他们认为这是上述问题的罪魁祸首。

具体而言,他们通过简单的实验,发现weight clipping会导致两大问题:模型建模能力弱化,以及梯度爆炸或消失。

他们提出的替代方案是给Critic loss加入 gradient penalty (GP) ,这样,新的网络模型就叫 WGAN-GP 。

GP项的设计逻辑是:当且仅当一个可微函数的梯度范数(gradient norm)在任意处都不超过1时,该函数满足1-Lipschitz条件。至于为什么限制Critic的梯度范数趋向1(two-sided penalty)而不是小于1(one-sided penalty),作者给出的解释是,从理论上最优Critic的梯度范数应当处处接近1,对Lipschitz条件的影响不大,同时从实验中发现two-sided penalty效果比one-sided penalty略好。

另一个值得注意的地方是,用于计算GP的样本 是生成样本和真实样本的线性插值,直接看算法流程更容易理解:

最后,该论文也通过实验说明,WGAN-GP在训练的速度和生成样本的质量上,都略胜WGAN一筹。

WGAN的介绍

参考技术A GAN模型由生成式模型(generative model)和判别式模型(discriminative model)充当,这里以生成图片为例进行说明。它们的功能分别是:

- 生成模型G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。

·判别模型D是用来判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”。

·不收敛(non-convergence)的问题

·难以训练

      梯度消失(gradient vanishing)

      崩溃问题(mode collapse)

·模型过于自由不可控

参考:

·Goodfellow Ian, Pouget-AbadieJ, Mirza M, et al. Generative adversarialnets[C]//Advances in NeuralInformation Processing Systems. 2014: 2672-2680.

·生成式对抗网络GAN研究进展(二)——原始GAN http://blog.csdn.net/solomon1558/article/details/52549409

主要发展

CGAN(条件生成对抗网络)

针对问题:模型过于自由不可控。

方法:输入更多信息到GAN模型学习,生成更好的样本。

效果:提高生成图像的质量,明确控制图像的某些方面。

参考:【1】Mirza M, Osindero S.Conditional Generative Adversarial Nets[J]. Computer Science, 2014:2672-2680.

【2】生成式对抗网络GAN研究进展(三)——条件GAN

http://blog.csdn.net/solomon1558/article/details/52549409

DCGAN(深度卷积生成对抗网络)

结合了有监督学习的CNN和无监督的GAN

针对问题:GAN训练不稳定,经常生成无意义的输出。

方法:生成模型和判别模型均采用CNN模型,并在结构上做了一些改变。

参考:【1】Radford A, Metz L, Chintala S. Unsupervised representation learningwith deep convolutional generative adversarial networks[J]. arXiv preprintarXiv:1511.06434, 2015.

【2】生成式对抗网络GAN研究进展(五)——Deep Convolutional Generative

Adversarial Nerworks,DCGAN

http://blog.csdn.net/solomon1558/article/details/52573596

WGAN的贡献

原始GAN训练困难的分析

1)训练目标:

x表示真实图片,z表示输入G网络的噪声,而G(z)表示G网络生成的图片。

判别器LOSS(最小化):

生成器LOSS(最小化):

2)训练过程:

· 先固定生成器,训练判别器达到最优,然后训练生成器。

· 利用SGD训练判别器达到最优解为:

训练生成器(在判别器最优时):

A、

最终变换形式:即最小化Pr和Pg之间的JS散度。

结论: 由于P_r与P_g几乎不可能有不可忽略的重叠,所以无论它们相距多远JS散度都是常数log 2,最终导致生成器的梯度(近似)为0,梯度消失。

B、

最终变换形式:

最小化目标分析:

u最小化生成分布与真实分布的KL散度,却又要最大化两者的JS散度,在数值上则会导致梯度不稳定。

uKL散度会造成两种错误:生成器没能生成真实样本(缺乏多样性)

生成器生成不真实样本(缺乏准确性)

小结: 在原始GAN的(近似)最优判别器下,第一种生成器loss面临梯度消失问题,第二种生成器loss面临优化目标荒谬、梯度不稳定、对多样性与准确性惩罚不平衡导致mode collapse这几个问题。

WGAN的内容

Wasserstein距离:

(EM距离)

相比KL散度、JS散度的优越性:即便两个分布没有重叠,Wasserstein距离仍能反映它们的远近。KL散度和JS散度是突变的,Wasserstein距离却是平滑的,可以提供有意义的梯度。

WGAN形式

对偶问题:

要求函数f的导函数绝对值不超过K的条件下,对所有可能满足条件的f取到上式的上界,然后再除以K。

进而,用该距离做GAN的LOSS函数,可得:

生成器loss函数:

判别器loss函数:

可以表示训练进程中,其数值越小,表示真实分布与生成分布的Wasserstein距离越小,GAN训练得越好。

·判别器所近似的Wasserstein距离与生成器的生成图片质量高度相关

总结:EM距离相对KL散度与JS散度具有优越的平滑特性,理论上可以解决梯度消失问题。在此近似最优判别器下优化生成器使得Wasserstein距离缩小,就能有效拉近生成分布与真实分布。WGAN既解决了训练不稳定的问题,也提供了一个可靠的训练进程指标,而且该指标确实与生成样本的质量高度相关。

参考:

【1】M. Arjovsky, S.Chintala, and L. Bottou. Wasserstein gan. ArXiv,2017.

【2】令人拍案叫绝的Wasserstein GAN https://zhuanlan.zhihu.com/p/25071913

以上是关于WGAN的来龙去脉的主要内容,如果未能解决你的问题,请参考以下文章

深度学习(五十四)图片翻译WGAN实验测试

GAN1-对抗神经网络梳理(GAN,WGAN,WGAN-GP)

WGAN的介绍

wgan pytorch,pyvision, py-faster-rcnn等的安装使用

使用残差网络与wgan制作二次元人物头像

使用残差网络与wgan制作二次元人物头像