生成对抗网络GAN

Posted zhiyong_will

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了生成对抗网络GAN相关的知识,希望对你有一定的参考价值。

1. 概述

生成对抗网络GAN(Generative adversarial nets)[1]是由Goodfellow等人于2014年提出的基于深度学习模型的生成框架,可用于多种生成任务。从名称也不难看出,在GAN中包括了两个部分,分别为”生成”和“对抗”,整两个部分也分别对应了两个网络,即生成网络(Generator) G G G和判别网络(Discriminator) D D D,为描述简单,以图像生成为例:

  • 生成网络(Generator) G G G用于生成图片,其输入是一个随机的噪声 z \\boldsymbolz z,通过这个噪声生成图片,记作 G ( z ) G\\left ( \\boldsymbolz \\right ) G(z)
  • 判别网络(Discriminator) D D D用于判别一张图片是否是真实的,对应的,其输入是一整图片 x \\boldsymbolx x,输出 D ( x ) D\\left ( \\boldsymbolx \\right ) D(x)表示的是图片 x \\boldsymbolx x为真实图片的概率

在GAN框架的训练过程中,希望生成网络 G G G生成的图片尽量真实,能够欺骗过判别网络 D D D;而希望判别网络 D D D能够把 G G G生成的图片从真实图片中区分开。这样的一个过程就构成了一个动态的“博弈”。最终,GAN希望能够使得训练好的生成网络 G G G生成的图片能够以假乱真,即对于判别网络 D D D来说,无法判断 G G G生成的网络是不是真实的。

综上,训练好的生成网络 G G G便可以用于生成“以假乱真”的图片。

2. 算法原理

2.1. GAN的框架结构

GAN的框架是由生成网络 G G G和判别网络 D D D这两种网络结构组成,通过两种网络的“对抗”过程完成两个网络的训练,GAN框架由下图所示:

由生成网络 G G G生成一张“Fake image”,判别网络 D D D判断这张图片是否来自真实图片。

2.2. GAN框架的训练过程

在GAN的训练过程中,其最终的目标是使得训练出来的生成模型 G G G生成的图片与真实图片具有相同的分布,其过程可通过下图描述[2]:

假设有一个先验分布 p z ( z ) p_\\boldsymbolz\\left ( \\boldsymbolz \\right ) pz(z),如上图中的unit gaussian,通过采样得到其中的一个样本点 z \\boldsymbolz z。对于真实的图片,事先对于其分布是未知的,即上图中的 p ( x ) p\\left ( \\boldsymbolx \\right ) p(x)未知。为了使得能与真实图片具有相同的分布,通过一个生成模型将先验分布映射到另一个分布,生成模型记为 G ( z ; θ g ) G\\left ( \\boldsymbolz;\\theta _g \\right ) G(z;θg),其中 θ g \\theta _g θg为生成模型的参数,这里的生成模型可以是一个前馈神经网络MLP, θ g \\theta _g θg便为该神经网络的参数。通过多次的采样,便可以刻画出生成的分布 p ^ ( x ) \\hatp\\left ( \\boldsymbolx \\right ) p^(x),此时需要计算其与真实的分布 p ( x ) p\\left ( \\boldsymbolx \\right ) p(x)之间的相关性,即需要一个判别模型来定量表示两个分布之间的相关性,这里可以通过另一个前馈神经网络MLP,判别模型记为 D ( x ; θ d ) D\\left ( \\boldsymbolx;\\theta _d \\right ) D(x;θd),其中 D ( x ; θ d ) D\\left ( \\boldsymbolx;\\theta _d \\right ) D(x;θd)的输出是一个标量,表示的是 x \\boldsymbolx x来自真实的分布,而不是来自于生成模型构造出的分布的概率。

对于这样的一个过程中,有两个模型,分别为生成模型 G ( z ; θ g ) G\\left ( \\boldsymbolz;\\theta _g \\right ) G(z;θg)和判别模型 D ( x ; θ d ) D\\left ( \\boldsymbolx;\\theta _d \\right ) D(x;θd),在GAN中,生成模型和判断模型分别对应了一个神经网络,以下都称为生成网络和判别模型。GAN希望的是对于判别网络,其能够正确判定数据是否来自真实的分布,对于生成网络,其能够尽可能使得生成的数据能够“以假乱真”,使得判别网络分辨不了。这样的训练过程是一个动态的“博弈”过程,通过交替训练,最终使得生成网络 G G G生成的图片能够“以假乱真”,其具体过程如下图所示:

如上图(a)中,黑色的虚线表示的是从真实的分布 p x p_\\boldsymbolx px,绿色的实线表示的是需要训练的生成网络的生成的分布 p g ( G ) p_g\\left ( G \\right ) pg(G),蓝色的虚线表示的是判别网络,最下面的横线 z \\boldsymbolz z表示的是从一个先验分布(如图中是一个均匀分布)采样得到的数据点,中间的横线 x \\boldsymbolx x表示的真实分布,两条横线之间的对应关系表示的是生成网络将先验分布映射成一个生成分布 p g ( G ) p_g\\left ( G \\right ) pg(G)。从图(a)到图(d)表示了一个完整的交替训练过程,首先,如图(a)所示,当通过先验分布采样后的数据经过生成网络 G G G映射后得到了图上绿色的实线代表的分布,此时判别网络 D D D并不能区分数据是否来自真实数据,通过对判别网络的训练,其能够正确地判断生成的数据是否来自真实数据,如图(b)所示;此时更新生成网络 G G G,通过对先验分布重新映射到新的生成分布上,如图(c)中的绿色实线所示。依次交替完成上述步骤,当达到一定迭代的代数后,达到一个平衡状态,此时 p g = p d a t a p_g=p_data pg=pdata,判别网络 D D D将不能区分图片是否来自真实分布,且 D ( x ) = 1 2 D\\left ( \\boldsymbolx \\right )=\\frac12 D(x)=21

2.3. 价值函数

对于GAN框架,其价值函数 V ( G , D ) V\\left ( G,D \\right ) V(G,D)为:

m i n G    m a x D    V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g    D ( x ) ] + E z ∼ p z ( z ) [ l o g    ( 1 − D ( G ( z ) ) ) ] \\undersetGmin\\; \\undersetDmax\\; V\\left ( D,G \\right )=\\mathbbE_\\boldsymbolx\\sim p_data\\left ( \\boldsymbolx \\right )\\left [ log\\; D\\left ( \\boldsymbolx \\right ) \\right ]+\\mathbbE_\\boldsymbolz\\sim p_\\boldsymbolz\\left ( \\boldsymbolz \\right )\\left [ log\\; \\left ( 1-D\\left ( G\\left ( \\boldsymbolz \\right ) \\right ) \\right ) \\right ] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中, E x ∼ p d a t a ( x ) [ l o g    D ( x ) ] \\mathbbE_\\boldsymbolx\\sim p_data\\left ( \\boldsymbolx \\right )\\left [ log\\; D\\left ( \\boldsymbolx \\right ) \\right ] Expdata(x)[logD(x)]表示的是 l o g    D ( x ) log\\; D\\left ( \\boldsymbolx \\right ) logD(x)的期望,同理, E z ∼ p z ( z ) [ l o g    ( 1 − D ( G ( z ) ) ) ] \\mathbbE_\\boldsymbolz\\sim p_\\boldsymbolz\\left ( \\boldsymbolz \\right )\\left [ log\\; \\left ( 1-D\\left ( G\\left ( \\boldsymbolz \\right ) \\right ) \\right ) \\right ] Ezpz(z)[log(1D(G(z)))]表示的是 l o g    ( 1 − D ( G ( z ) ) ) log\\; \\left ( 1-D\\left ( G\\left ( \\boldsymbolz \\right ) \\right ) \\right ) log(1关于GAN生成式对抗网络中判别器的输出的问题

GAN (生成对抗网络) 手写数字图片生成

万物皆可 GAN生成对抗网络生成手写数字 Part 1

生成对抗网络GAN

PyTorch实现简单的生成对抗网络GAN

对抗生成网络GAN系列——GANomaly原理及源码解析