GAN实战笔记——第一章GAN简介

Posted 墨戈

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GAN实战笔记——第一章GAN简介相关的知识,希望对你有一定的参考价值。

GAN简介

一、什么是GAN

GAN是一类由两个同时训练的模型组成的机器学习技术:一个是生成器,训练其生成伪数据:另一个是判别器,训练其从真实数据中识别伪数据。

  • 生成(generative)一词预示着模型的总目标——生成新数据GAN通过学习生成的数据取决于所选择的训练集,例如,如果我们想用GAN合成一幅看起来像达・芬奇作品的画作,就得用达·芬奇的作品作为训练集。
  • 对抗(adversarial)一词则是指构成GAN框架的两个动态博弈、竞争的模型:生成器和判别器生成器的目标是生成与训练集中的真实数据无法区分的伪数据——在刚才的示例中这就意味着能够创作出和达・芬奇画作一样的绘画作品。判别器的目标是能辨别出哪些是来自训练集的真实数据,哪些是来自生成器的伪数据。也就是说,判别器充当着艺术品鉴定专家的角色,评估被认为是达·芬奇画作的作品的真实性。这两个网络不断新地“斗智斗勇”,试图互相欺骗:生成器生成的伪数据越逼真,判别器辨别真伪的能力就要越强
  • 网络(network)一词表示最常用于生成器和判别器的一类机器学习模型:神经网络。依据GAN实现的复杂程度,这些网络包括从最简单的前馈神经网络到卷积神经网络以及更为复杂的变体。

二、GAN是如何工作的

还有一个比喻经常用来形容GAN,假币制造者(生成器)和试图逮捕他的侦探(判别器)——假钞看起来越真实,就需要越好的侦探才能辨别出他们,反之亦然。

用更专业的术语来说,生成器的目标是生成能最大程度有效捕捉训练集特征的样本,以至于生成出的样本与训练数据别无二致。生成器可以看作一个反向的对象识别模型——对象识别算法学习图像中的模式,以期能够识别图像的内容。生成器不是去识别这些模式,而是要学会从头开始学习创建它们,实际上,生成器的输入通常不过是一个随机数向量。

生成器通过从判别器的分类结果中接收反馈来不断学习。判别器的目标是判断一个特定的样本是真的(来自训练集)还是假的(由生成器生成)。因此,每当判别器“上当受骗”将假的图像错判为真实图像时,生成器就会知道自己做得很好:相反,每当判别器正确地将生成器生成的假图像辨别出来时,生成器就会收到需要继续改进的反馈。

判别器也会不断地改善,像其他分类器一样,它会从预测标签与真实标签(真或假)之间的偏差中学习。所以随着生成器能更好地生成更逼真的数据,判别器也能更好地辨别真假数据,两个网络都在同时不断地改进着。

表1.1 生成器和判别器的关键信息

生成器 判别器
输入 一个随机数向量 判别器的输入有两个来源:来自训练集的真实样本和来自生成器的伪样本
输出 尽可能令人信服的伪样本 预测输入样本是真实的概率
目标 生成与训练集中数据别无二致的伪数据 区分来自生成器的伪样本和来自训练集的真实样本

三、GAN的结构

假定我们的目标是教GAN生成逼真的手写数字。GAN的核心结构如下图所示。

让我们看看其中的细节。

(1) 训练数据集——包含真实样本的数据集,是我们希望生成器能以近乎完美的质量去学习模仿的数据。在这个示例中,数据集由手写数字的图像组成。该数据集用作判别器网络的输入(\\(x\\))。

(2) 随机噪声向量——生成器网络的初始输入(z)。此输入是一个由随机数组成的向量,生成器将其用作合成伪样本的起点。

(3) 生成器网络——生成器接收随机数向量(z)作为输入并输出伪样本(x*)。它的目标是生成和训练数据集中的真实样本别无二致的伪样本。(卷积神经网络)

(4) 判别器网络——判别器接收来自训练集的真实样本(x)或生成器生成的伪样本(x*)作为输入。对每个样本,判别器会进行判定并输出其为真实的概率。(反卷积神经网络)

(5) 迭代训练/调优——对于每个判别器的预测,我们会衡量它效果有多好——就像对常规的分类器一样——并用结果反向传播去迭代优化判别器网络和生成网络。

  • 更新判别器的权重和偏置,以最大化其分类的精确度(最大化正确预测的概率:x为真,x*为假)。
  • 更新生成器的权重和偏置,以最大化判别器将x*误判为真的概率。

3.1 GAN的训练

为了了解GAN各组件的用途,我们首先介绍GAN的训练算法,其次演示训练过程,以便我们能够可以清楚的看到实际的框架图。

GAN训练算法
	对于每次训练迭代,执行如下操作。
	(1)训练判别器
		a.从训练集中随机抽取真实样本x。
		b.获取一个新的随机噪声向量z,用生成器网络合成一个伪样本x*。
		c.用判别器网络对x和x*进行分类。
		d.计算分类误差并反向传播总误差以更新判别器的可训练参数,寻求最小化分类误差。
	(2)训练生成器
		a.获取一个新的随机噪声向量z,用生成器网络合成一个伪样本x*。
		b.用判别器网络对x*进行分类。
		c.计算分类误差并反向传播以更新生成器的可训练参数,寻求最大化判别器误差。
	结束

GAN训练过程可视化

GAN的训练算法如下图所示,其中的字母表示GAN训练算法中的步骤。

子程序图示说明

(1)训练判别器

​ a. 从训练集中随机抽取真实样本x。

​ b. 获取一个新的随机噪声向量z,用生成器网结合成一个伪样本x*。

​ c. 用判别器网络对x和x*进行分类。

​ d. 计算分类误差并反向传播总误差以更新判别器的权重和偏置,寻求最小化分类误差。

(2)训练生成器

​ a. 获取一个新的随机噪声向量z,用生成器网络合成一个伪样本x*。

​ b.用判别器网络对x*进行分类。

​ c.计算分类误差并反向传播以更新生成器的可训练参数,寻求最大化判别器误差。

3.2 达到平衡

对于一般的神经网络,我们通常有一个明确的目标去实现以及用来衡量效果。例如,当训练一个分类器时,我们度量在训练集和验证集上的分类误差,一旦发现验证集开始变坏,就停止进程(为了避免过拟合)。在GAN结构中,判别器网络和生成器网络有两个互为竞争对手的目标:一个网络越好,另一个就越差。那么我们如何决定何时停止进程呢?

这其实是一个零和博弈问题,即一方的收益等于另一方的损失。当一方提高一定程度时,另一方会恶化同样的程度。零和博弈都有一个纳什均衡点,那就是任何一方无论怎么努力都不能改善他们的处境或结果。

当满足以下条件时,GAN达到纳什均衡点

(1)生成器生成的伪样本与训练集中的真实数据别无二致。

(2)判别器所能做的只是随机猜测一个特定的样本是真的还是假的(也就是说,猜测一个示例为真的概率是50%)。

让我们来解释为何会出现这种情况。当每一个伪样本(x*)与来自训练集的真实样本无法区分时,判别器用任何手段都无法区分它们。因为判别器接收到的样本有一半是真的,半是假的,所以它所能做的最有用的事情就是抛硬币,以50%的概率把每个样本分为真和假。

同样,生成器也处于这样一个点上,它不能从进一步的调优中获得任何提高了。因为生成器生成的样本早已和真实样本无法区分了,以至于对随机噪声向量(z)转换为伪样本(x)的过程做出哪怕一丁点儿改变,也可能给判别器提供从真实样本中辨别出伪样本的机会,从而使生成器变得更糟。

当达到纳什均衡时,GAN就被认为是收敛的。这是一个棘手的问题,在实践中,由于在非凸博弈中实现收敛所涉及的巨大复杂性,几乎不可能达到GAN的纳什均衡。实际上,GAN的收敛仍是GAN研究中最重要的开放性问题之一。

小结

  1. GAN是一种利用两个神经网络之间的动态竞争来合成真实数据样本的深度学习技术,例如能合成具有照片级真实感的虚假图像。构成一个完整GAN的两个网络如下:
    • 生成器,其目标是通过生成与训练数据集别无二致的数据来欺骗判别器;
    • 判别器,其目标是正确区分来自训练数据集的真实数据和由生成器生成的伪数据。

对抗生成网络GAN系列——GANomaly原理及源码解析

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊往期回顾:对抗生成网络GAN系列——GAN原理及手写数字生成小案例   对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例   对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战   对抗生成网络GAN系列——EGBAD原理及缺陷检测实战

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

 

文章目录

对抗生成网络GAN系列——GANomaly原理及源码解析

写在前面

​ 在前面,我已经介绍过好几篇有关GAN的文章,链接如下:

  这篇文章我将来为大家介绍GANomaly,论文名为:Semi-Supervised Anomaly Detection via Adversarial Training。这篇文章同样是实现缺陷检测的,因此在阅读本文之前建议你对使用GAN网络实现缺陷检测有一定的了解,可以参考上文链接中的[4]和[5]。

  准备好了吗,嘟嘟嘟,开始发车。🚖🚖🚖

 

GANomaly原理解析


【阅读此部分前建议对GAN的原理及GAN在缺陷检测上的应用有所了解,详情点击写在前面中的链接查看,本篇文章我不会再介绍GAN的一些先验知识。】


GANomaly结构

​ 这部分为大家介绍GANomaly的原理,其实我们一起来看下图就足够了:

图1 GANomaly结构图

  我们还是先来对上图中的结构做一些解释。从直观的颜色上来看,我们可以分成两类,一类是红色的Encoder结构,一类是蓝色的Decoder结构。Encoder主要就是降维的作用啦,如将一张张图片数据压缩成一个个潜在向量;相反,Decoder就是升维的作用,如将一个个潜在向量重建成一张张图片。按照论文描述的结构来分,可以分成三个子结构,分别为生成器网络G,编码器网络E和判别器网络D。下面分别来介绍介绍这三个子结构:

  • 生成器网络G
      生成器网络G由两个部分组成,分别为编码器 G E ( x ) ) G_E(x)) GE(x))和解码器 G D ( z ) G_D(z) GD(z),其实这就是一个自动编码器结构,主要用来学习输入x的数据分布并重建图像 x ^ \\hat x x^。我们一个个来看,先看 G E ( x ) G_E(x) GE(x)结构,假设我们的输入x维度为 R C × H × W \\mathbbR^C×H×W RC×H×W,经过 G E ( x ) G_E(x) GE(x)结构后,变成一个向量 z z z,其维度为 R d \\mathbbR^d Rd G E ( x ) G_E(x) GE(x)具体结构很简单啦,这里就不详细介绍了。我会在源码解析部分给出,大家肯定一看就会。】接着我们来看 G D ( z ) G_D(z) GD(z)结构,它会将刚刚得到的向量z上采样成 x ^ \\hat x x^ x ^ \\hat x x^的维度和 x x x一致,都为 R C × H × W \\mathbbR^C×H×W RC×H×W。关于 G D ( Z ) G_D(Z) GD(Z)结构也很简单,其主要用到了转置卷积,对于转置卷积不了解的可以看博客[2]了解详情。生成器网络G就为大家介绍完了,是不是发现很简单呢。总结下来就两步,第一步让输入x通过 G E ( x ) G_E(x) GE(x)得到z,第二步让z通过 G D ( Z ) G_D(Z) GD(Z)变成 x ^ \\hat x x^。这两步也可以用一步表示,即 x ^ = G ( x ) \\hat x=G(x) x^=G(x)

      思来想去我还是想在这里给大家抛出一个问题,我们传统的GAN是怎么通过生成器来构建假图像的呢?和GANomaly有区别吗?其实这个问题的答案很简单,大家都稍稍思考一下,我就不给答案了,不明白的评论区见吧!!!🥂🥂🥂

  • 编码器网络E

    ​   编码器网络E的作用是将生成器得到的 x ^ \\hat x x^压缩成一个向量 z ^ \\hat z z^,是不是发现和生成器网络中的 G E ( x ) G_E(x) GE(x)很像呢,其实呀,它俩的结构就是完全一样的,生成的 z ^ \\hat z z^ x ^ \\hat x x^ 的维度一致,这是方便后面的损失比较。

 

  • 判别器网络D

    ​   判别器网络D和我们之前介绍DCGAN时的结构是一样的,都是将真实数据 x x x和生成数据 x ^ \\hat x x^输入网络,然后得出一个分数。

 

GANomaly损失函数

  GANomaly的损失函数分为两部分,第一部分是生成器损失,第二部分为判别器损失,下面我们分别来进行介绍:

  • 生成器损失函数

    ​ 生成器损失函数又由三个部分组成,分别如下:

    • Adversari Loss

      我还是直接上公式吧,如下:

      L a d v = E x ∼ p x ∣ ∣ f ( x ) − E x ∼ p x f ( G ( x ) ) ∣ ∣ 2 L_adv=E_x \\sim px||f(x)-E_x \\sim pxf(G(x))||_2 Ladv=Expxf(x)Expxf(G(x))2

      这个公式对应图一中的 L a d v = ∣ ∣ f ( x ) − f ( x ^ ) ∣ ∣ 2 L_adv=||f(x)-f(\\hat x)||_2 Ladv=f(x)f(x^)2🍵🍵🍵这个损失函数应该很好理解,在前面介绍的GAN网络都有提及, f ( ∗ ) f(*) f()表示判别器网络某个中间层的输出。这个损失函数的作用就是让两张图像 x 和 x ^ x和\\hat x xx^尽可能接近,也就是让生成器生成的图片更加逼真。


    • Contextual Loss

      同样的,直接来上公式,如下:

      L c o n = E x ∼ p x ∣ ∣ x − G ( x ) ∣ ∣ 1 L_con=E_x \\sim px||x-G(x)||_1 Lcon=ExpxxG(x)1

      这个公式对应图一中的 L c o n = ∣ ∣ x − x ^ ∣ ∣ 1 L_con=||x-\\hat x||_1 Lcon=xx^1🍵🍵🍵这个函数其实也是要让两张图像 x 和 x ^ x和\\hat x xx^尽可能接近。至于这里为什么用的是L1范数而不是L2范数,作者在论文中说这里使用L1范数的效果要比使用L2范数的效果好,这属于实验得到的结论,大家也不用过于纠结。

    • Encoder Loss

      话不多说,上公式,如下:

      L e n c = E x ∼ p x ∣ ∣ G E ( x ) − E ( G ( x ) ) ∣ ∣ 2 L_enc=E_x \\sim px||G_E(x)-E(G(x))||_2 Lenc=ExpxGE(x)E(G(x))2

      这个公式对应图一中的 L e n c = ∣ ∣ z − z ^ ∣ ∣ 2 L_enc=||z-\\hat z||_2 Lenc=zz^2🍵🍵🍵这里的损失函数在我看来主要作用就是让我们在推理过程中的效果更好,这里就像AnoGAN中不断搜索最优的那个z的作用。


      如果大家这里读过cycleGAN的论文的话,可能会觉得这个损失函数有点类似cycleGAN中的循环一致性损失。我觉得这篇文章的思想可能借鉴了cycleGAN中的思想,感兴趣的可以去阅读一下,非常有意思的一篇文章!!!🥃🥃🥃



    生成器总的损失是上述三种损失的加权和,如下:
    L = w a d v L a d v + w c o n L c o n + w e n c L e n c L=w_advL_adv+w_conL_con+w_encL_enc Pytorch入门实战:基于GAN生成简单的动漫人物头像

    第一节2:GAN经典案例之MNIST手写数字生成

    美团云Tensorflow生成对抗网络(Generative Adversarial Networks)实战案例

    Pytorch搭建基本的GAN模型及训练过程

    GAN实战之Pytorch 使用pix2pixGAN生成建筑物图片

    GAN 转