GAN的基本概念
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GAN的基本概念相关的知识,希望对你有一定的参考价值。
参考技术A 1. 组成1) 判别算法discriminative algorithms
判别算法根据输入数据的feature进行分类label; 比如这封邮件是不是垃圾邮件
2) 生成算法 generative algorithms
生成算法根据特定的label预测feature
一个卷积神经网络
生成新数据
4) 判别器 discriminator
一个卷积神经网络
评估生成数据的真实性
2. 工作过程
1) 简述
生成模型生成一些图片->判别模型学习区分生成的图片和真实图片->生成模型根据判别模型改进自己,生成新的图片->····
2) 过程
a) 在以下两个Minibatch应用梯度下降等优化算法
训练样本
生成样本
b) 优化
3) 结果
生成模型与判别模型无法提高自己——即判别模型无法判断一张图片是生成出来的还是真实的而结束
二、 数学描述
1. 生成模型 g(z)
1) 概念
输入z是一个随机噪声
g的输出就是一张图片
2. 判别模型 D(x)
1) 概念
输入X是从模型中的抽取的样本
D(x)的输出是0-1范围内的一个实数,用来判断这个图片是一个真实图片的概率是多大
2) 公式
解释: D=1(真实图片),则P_data=1, P_model=0
D=0(假的图片),则 P_data=0
3. 优化方法
o 交互迭代
固定g,优化D,一段时间后,固定D再优化g,直到过程收敛。
4. 代价函数
1) 判别器的代价函数J(D)
迄今为止,所有的判别其代价函数都一样,不一样的是生成器的代价函数。
这是一个标准的cross-entropy cost
真实数据的label=1,生成器生成的数据的label=0
目标是最小化判别器的代价函数
2) Minimax 博弈/零和博弈
3) Non-saturating heuristic启发式,非饱和博弈
启发式驱动是说按照实际操作改变策略,而不是遵循理论。
为了防止生成器的梯度消失,把生成器的cost变成:
4) Maximum likelihood最大似然博弈
5) Divergence的选择
6) Generator 的Cost function的比较
结论:使用non-saturating
三、 DCGAN
1. 引入——CNN
1) 概念
(convolutional neural network,卷积神经网络)。深度神经网络的一种,可以通过卷积层(convolutional layer)提取不同层级的信息
2) 输入
图片
3) 输出
图片
抽象表达(纹理/形状etc)
2. 简介
o 人们为图像生成设计了一种类似反卷积的结构:Deep convolutional NN for GAN(DCGAN)
o 是GAN的使用场景——在图片中的生成模型DCGAN。
3. 输入
o 随机噪声向量(高斯噪声etc)
4. 处理流程
o 输入通过与CNN类似但是相反的结构,将输入放大成二维数据。
o 通过生成模型+判别模型
o 生成图片
四、 训练GAN的tips
1. 用label训练
2. One-sided label smoothing
防止判别器的极端的可信分类(某一类的可能性为1)
帮助判别器抵抗生成器的攻击
3. 注意Batch norm的使用
1) Feature norm的定义
subtract mean, divide by standard deviation
2) 注意
会导致batch之间相互影响,应该使用virtual batch norm
4. 平衡G和D
D应该获胜,应该多锻炼D,D应该更加复杂
https://deeplearning4j.org/generative-adversarial-network#tips-in-training-a-gan
1. discriminator 与generator,固定一个,训练另一个
2. 开始训练generator之前预先训练discriminator
3. GAN训练很费时间,建议用GPU不要用CPU
五、 GAN与其他神经网络的比较
1. Auto encoder/decoder
可以压缩原始数据
2. VAE(Variational变化的 auto encoder)
增加了额外的限制
可以压缩原始数据
可以生成数据(生成效果没有GAN好)
六、 GAN的有趣应用
https://deeplearning4j.org/generative-adversarial-network#unclassified-papers--resources
七、 评估
1. inception scores
It's particularly interesting because it seems to be a very reliable way to "quantify realism" in GANs,
八、 前沿研究
1. 解决Non-convergence
2. 如何评估生成模型
3. 离散的输出
4. 半监督化学习
5. 。。。
GAN的学习记录
最近看了一下神经网络和卷积神经网络(CNN)的基础概念,然后开始看生成对抗网络(GAN)的基础知识,之后会自己写一下代码,用GAN对数据集进行训练。
一、12月的计划:
(1)先看懂GAN的基础理论
(2)找一些代码,想办法把轴承的数据集放到GAN里面训练
(3)看老师给的论文,以及自己去看GAN的变种论文、gihub的开源代码,并进行复现
二、学习方法:
1.b站有李宏毅的视频,讲的比较清晰,看完GAN的4讲内容
2.这个链接里,有人整理好了各种GAN论文的原文地址,之后的计划就是三个月内都读一遍?
3.一些参考:
生成对抗网络-改进方法|深度学习(李宏毅)(二十四)笔记
三、码住博客
-
GAN的原理,及其他经典变种,例如DCGAN、WGAN和WGAN-GP、Conditional GAN的介绍。不过这篇写的比较浅显,只用来快速了解GAN使用。下面的是看这篇博文时去查的东西
(1)池化层、FC层、BN层的作用。BN层训练的步骤。
(2)内部协变量偏移(Internal Covariate Shift)和批归一化(Batch Normalization)
(3)激活函数Relu 及 leakyRelu
(4)转置卷积
(5)stride卷积 -
我在github找到了一个各种gan代码的开源,地址为:PyTorch-GAN-master
在gan里跑了一下CWRU的轴承数据集,发现 dloss和gloss十分不稳定,且gloss根本不收敛,后来想起来,轴承数据是有标签数据的,不能直接拿来训练,需要用CGAN方法。
下面是直接训练时的LOSS
-
CGAN 简单来说,就是讲标签信息和噪声z一起输入生成器,然后判别器判断对错来更新。
(1)简单用mnist数据集的效果来观察GAN和CGAN,不过每次200个epochs都要训练3个小时左右,好慢。。。
这是GAN方法的结果:
这是CGAN方法的结果,因为懒得训练了,就训了50个epochs,不过前50epochs的效果还是明显好于GAN的:
从生成图片的结果可以看出,基本上GAN生成的最好的图片,也不如CGAN后几十张生成的图片的质量。
(2)注释代码
1.通俗讲解pytorch中nn.Embedding原理及使用
nn.Embedding()参数的理解
(3)CGAN在轴承数据上的复现。。
以上是关于GAN的基本概念的主要内容,如果未能解决你的问题,请参考以下文章
李宏毅2021机器学习深度学习6-1 生成式对抗网络GAN1 基本概念介绍