对抗生成网络GAN系列——Spectral Normalization原理详解及源码解析

Posted 秃头小苏

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了对抗生成网络GAN系列——Spectral Normalization原理详解及源码解析相关的知识,希望对你有一定的参考价值。

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊专栏推荐:深度学习网络原理与实战

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

 

对抗生成网络GAN系列——Spectral Normalization原理详解及源码解析

写在前面

Hello,大家好,我是小苏🧒🏽🧒🏽🧒🏽

在前面的文章中,我已经介绍过挺多种GAN网络了,感兴趣的可以关注一下我的专栏:深度学习网络原理与实战 。目前专栏主要更新了GAN系列文章、Transformer系列和语义分割系列文章,都有理论详解和代码实战,文中的讲解都比较通俗易懂,如果你希望丰富这方面的知识,建议你阅读试试,相信你会有蛮不错的收获。🍸🍸🍸

在阅读本篇教程之前,你非常有必要阅读下面两篇文章:

其实啊,我相信大家来看这篇文章的时候,一定是对上文提到的文章有所了解了,因此大家要是觉得自己对GAN和WGAN了解的已经足够透彻了,那么完全没有必要再浪费时间阅读了。如果你还对它们有一些疑惑或者过了很久已经忘了希望回顾一下的话,那么文章[1]和文章[2]获取对你有所帮助。

大家准备好了嘛,我们这就开始准备学习Spectral Normalization啦!🚖🚖🚖

 

Spectral Normalization原理详解

​  首先,让我们简单的回顾一下WGAN。🌞🌞🌞由于原始GAN网络存在训练不稳定的现象,究其本质,是因为它的损失函数实际上是JS散度,而JS散度不会随着两个分布的距离改变而改变(这句不严谨,细节参考WGAN中的描述),这就会导致生成器的梯度会一直不变,从而导致模型训练效果很差。WGAN为了解决原始GAN网络训练不稳定的现象,引入了EM distance代替原有的JS散度,这样的改变会使生成器梯度一直变化,从而使模型得到充分训练。但是WGAN的提出伴随着一个难点,即如何让判别器的参数矩阵满足Lipschitz连续条件。

​  如何解决上述所说的难点呢?在WGAN中,我们采用了一种简单粗暴的方式来满足这一条件,即直接对判别器的权重参数进行剪裁,强制将权重限制在[-c,c]范围内。大家可以动动我们的小脑瓜想想这种权重剪裁的方式有什么样的问题——(滴,揭晓答案🍍🍍🍍)如果权重剪裁的参数c很大,那么任何权重可能都需要很长时间才能达到极限,从而使训练判别器达到最优变得更加困难;如果权重剪裁的参数c很小,这又容易导致梯度消失。因此,如何确定权重剪裁参数c是重要的,同时这也是困难的。WGAN提出之后,又提出了WGAN-GP来实现Lipschitz 连续条件,其主要通过添加一个惩罚项来实现。【关于WGAN-GP我没有做相关教程,如果不明白的可以评论区留言】那么本文提出了一种归一化的手段Spectral Normalization来实现Lipschitz连续条件,这种归一化具体是怎么实现的呢,下面听我慢慢道来。🍻🍻🍻


我们还是来先回顾一下Lipschitz连续条件,如下:

​             ∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ K ∣ x 1 − x 2 ∣ |f(x_1)-f(x_2)| \\le K|x_1-x_2| f(x1)f(x2)Kx1x2

这个式子限制了函数 f ( ⋅ ) \\rmf( \\cdot ) f()的导数,即其导数的绝对值小于K, ∣ f ( x 1 ) − f ( x 2 ) ∣ ∣ x 1 − x 2 ∣ ≤ K \\frac|f(x_1)-f(x_2)||x_1-x_2| \\le K x1x2f(x1)f(x2)K。 🍋🍋🍋

本文介绍的Spectral Normalization的K=1,让我们一起来看看怎么实现的吧!!!


  上文提到,WGAN的难点是如何让判别器的参数矩阵满足Lipschitz连续条件。那么我们就从判别器入手和大家唠一唠。实际上,判别器也是由多层卷积神经网络构成的,我们用下式表示第n层网络输出和第n-1层输入的关系:

​             X n = a n ( W n ⋅ X n − 1 + b n ) X_n=a_n(W_n \\cdot X_n-1+b_n) Xn=an(WnXn1+bn)

  其中 a n ( ⋅ ) a_n(\\cdot) an()表示激活函数, W n W_n Wn表示权重参数矩阵。为了方便起见,我们不设置偏置项 b n b_n bn,即 b n = 0 b_n=0 bn=0。那么上式变为:

​             X n = a n ( W n ⋅ X n − 1 ) X_n=a_n(W_n \\cdot X_n-1) Xn=an(WnXn1)

  再为了方便起见🤸🏽‍♂️🤸🏽‍♂️🤸🏽‍♂️,我们设 a n ( ⋅ ) a_n(\\cdot) an(),即激活函数为Relu。Relu函数在大于0时为y=x,小于0时为y=0,函数图像如下图所示:

​  这样的话式 X n = a n ( W n ⋅ X n − 1 ) X_n=a_n(W_n \\cdot X_n-1) Xn=an(WnXn1)可以写成 X n = D n ⋅ W n ⋅ X n − 1 X_n=D_n \\cdot W_n \\cdot X_n-1 Xn=DnWnXn1,其中 D n D_n Dn为对角矩阵。【大家这里能否理解呢?如果我们的输入为正数时,通过Relu函数值是不变的,那么此时 D n D_n Dn对应的对角元素应该为1;如果我们的输入为负数时,通过Relu函数值将变成0,那么此时 D n D_n Dn对应的对角元素应该为0。也就是说我们将 X n X_n Xn改写成 D n ⋅ W n ⋅ X n − 1 D_n \\cdot W_n \\cdot X_n-1 DnWnXn1形式是可行的。】

​  接着我们做一些简单的推理,得到判别器第n层输出和原始输入的关系,如下图所示:

  最后一层的输出 X n X_n Xn即为判别器的输出,接下来我们用 f ( x ) f(x) f(x)表示;原始输入数据 x 0 x_0 x0我们接下来用 x x x表示。则判别器最终输入输出的关系式如下:

​    f ( x ) = D n ⋅ W n ⋅ D n − 1 ⋅ W n − 1 ⋯ D 3 ⋅ W 3 ⋅ D 2 ⋅ W 2 ⋅ D 1 ⋅ W 1 ⋅ x f(x) = D_n \\cdot W_n \\cdot D_n - 1 \\cdot W_n - 1 \\cdots D_3 \\cdot W_3 \\cdot D_2 \\cdot W_2 \\cdot D_1 \\cdot W_1 \\cdot x f(x)=DnWnDn1Wn1D3W3D2W2D1W1x

  上文说到Lipschitz连续条件本质上就是限制函数 f ( ⋅ ) \\rmf( \\cdot ) f()的导数变化范围,其实就是对 f ( x ) f(x) f(x)梯度提出限制,如下:

∣ ∣ ∇ x f ( x ) ∣ ∣ 2 = ∣ ∣ D n ⋅ W n ⋅ D n − 1 ⋅ W n − 1 ⋯ D 3 ⋅ W 3 ⋅ D 2 ⋅ W 2 ⋅ D 1 ⋅ W 1 ∣ ∣ 2 ≤ ∣ ∣ D n ∣ ∣ 2 ⋅ ∣ ∣ W n ∣ ∣ 2 ⋅ ∣ ∣ D n −

以上是关于对抗生成网络GAN系列——Spectral Normalization原理详解及源码解析的主要内容,如果未能解决你的问题,请参考以下文章

GAN 系列的探索与pytorch实现 (数字对抗样本生成)

对抗生成网络GAN系列——GAN原理及手写数字生成小案例

对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例

对抗生成网络GAN系列——CycleGAN简介及图片春冬变换案例

机器学习-白板推导系列(三十一)-生成对抗网络(GAN,Generative Adversarial Network)

学习笔记-李宏毅GAN(生成对抗网络)全系列