Everything Can Be GAN: Generating Handwritten Digits with a Generative Adversarial Network, Part 1
Posted by 我是小白呀
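Overview
A GAN (Generative Adversarial Network) consists of two networks: a generator and a discriminator. Trained against each other, the two networks learn features from the data automatically, and the discriminator's judgments provide the signal used to optimize the generator.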
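For reference, the adversarial setup described above is usually written as the minimax objective from the original GAN formulation. The article does not state it, so the notation below ($G$, $D$, $p_{\text{data}}$, $p_z$) is an assumption made for illustration:

```latex
\min_G \max_D V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

Here $D(x)$ is the discriminator's estimated probability that $x$ is a real sample, and $G(z)$ is the image the generator produces from the latent vector $z$.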
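GAN network architecture
The generator takes a latent vector as input and outputs an image with the same pixel dimensions as a handwritten digit.
The discriminator takes an image as input, judges whether it comes from the dataset or from the generator, and outputs a label (Real / Fake).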
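GAN training procedure
Phase 1:
- Fix the discriminator and train the generator, so that the generator keeps improving and fools the discriminator
Phase 2:
- Fix the generator and train the discriminator, so that the discriminator keeps improving and the generator can no longer fool it
Then:
- Repeat phase 1 and phase 2, so that both the generator and the discriminator grow stronger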
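The two phases above can be sketched as a single training step. This is a minimal sketch, not the article's training code: the tiny `generator` and `discriminator` stand-ins, the binary cross-entropy loss, the Adam hyperparameters, and the random tensor used in place of real MNIST images are all assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, img_dim, batch_size = 100, 28 * 28, 16

# Hypothetical stand-in networks, much smaller than the article's models
generator = nn.Sequential(
    nn.Linear(latent_dim, 64), nn.LeakyReLU(0.2), nn.Linear(64, img_dim), nn.Tanh()
)
discriminator = nn.Sequential(
    nn.Linear(img_dim, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1), nn.Sigmoid()
)

criterion = nn.BCELoss()  # assumed loss: binary cross-entropy on Real/Fake labels
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

real_label = torch.ones(batch_size, 1)   # target for "Real"
fake_label = torch.zeros(batch_size, 1)  # target for "Fake"


def train_step(real_imgs):
    # Phase 1: discriminator fixed (only opt_g steps); the generator is
    # rewarded when the discriminator labels its output as real.
    opt_g.zero_grad()
    z = torch.randn(batch_size, latent_dim)
    fake_imgs = generator(z)
    g_loss = criterion(discriminator(fake_imgs), real_label)
    g_loss.backward()
    opt_g.step()

    # Phase 2: generator fixed (fake images are detached, only opt_d steps);
    # the discriminator learns to separate real from generated images.
    opt_d.zero_grad()
    real_loss = criterion(discriminator(real_imgs), real_label)
    fake_loss = criterion(discriminator(fake_imgs.detach()), fake_label)
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    opt_d.step()
    return g_loss.item(), d_loss.item()


# Random values in [-1, 1] stand in for a batch of flattened real images
real_imgs = torch.rand(batch_size, img_dim) * 2 - 1
for step in range(3):
    g_loss, d_loss = train_step(real_imgs)
```

"Fixing" a network here simply means that its optimizer does not step during the other network's phase; calling `detach()` on the fake images additionally keeps the discriminator's loss from propagating gradients back into the generator.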
Model details
Generator
import numpy as np
import torch.nn as nn


class Generator(nn.Module):
    """Generator: maps a latent vector to an image."""

    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            """
            Build one fully connected block.
            :param in_feat: input feature dimension
            :param out_feat: output feature dimension
            :param normalize: whether to add batch normalization
            :return: list of layers making up the block
            """
            layers = [nn.Linear(in_feat, out_feat)]
            # Normalization. Note: the second positional argument of
            # nn.BatchNorm1d is eps (not momentum), so this sets eps=0.8.
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            # Activation
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # [b, 100] => [b, 128]
            *block(latent_dim, 128, normalize=False),
            # [b, 128] => [b, 256]
            *block(128, 256),
            # [b, 256] => [b, 512]
            *block(256, 512),
            # [b, 512] => [b, 1024]
            *block(512, 1024),
            # [b, 1024] => [b, 28 * 28 * 1] => [b, 784]
            nn.Linear(1024, int(np.prod(img_shape))),
            # Activation: squash pixel values into [-1, 1]
            nn.Tanh()
        )

    def forward(self, z, img_shape):
        # [b, 100] => [b, 784]
        img = self.model(z)
        # [b, 784] => [b, 1, 28, 28]
        img = img.view(img.size(0), *img_shape)
        # Return the generated images
        return img
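To make the shape flow concrete, the snippet below rebuilds a shortened version of the stack and pushes a batch of random latent vectors through it. It is a sketch under assumptions: `img_shape = (1, 28, 28)` and `latent_dim = 100` are taken from the shape comments in the code above, the stack is cut down to two blocks, and `BatchNorm1d` is used with its default arguments.

```python
import numpy as np
import torch
import torch.nn as nn

img_shape = (1, 28, 28)  # assumed (channels, height, width) of an MNIST digit
latent_dim = 100         # assumed latent size, from the "[b, 100]" comments


def block(in_feat, out_feat, normalize=True):
    # Returns a plain Python list of layers; the caller unpacks it with *
    layers = [nn.Linear(in_feat, out_feat)]
    if normalize:
        layers.append(nn.BatchNorm1d(out_feat))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return layers


# Shortened stack (two blocks instead of four) to keep the example small
model = nn.Sequential(
    *block(latent_dim, 128, normalize=False),
    *block(128, 256),
    nn.Linear(256, int(np.prod(img_shape))),
    nn.Tanh(),
)

z = torch.randn(4, latent_dim)              # [4, 100] batch of latent vectors
flat = model(z)                             # [4, 784] flattened images
imgs = flat.view(flat.size(0), *img_shape)  # [4, 1, 28, 28]
```

Because `block` returns a list, the `*` unpacking places each layer directly inside `nn.Sequential` rather than nesting sub-modules, which is why a layer summary lists every layer at the top level.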
Network structure:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                  [-1, 128]          12,928
         LeakyReLU-2                  [-1, 128]               0
            Linear-3                  [-1, 256]          33,024
       BatchNorm1d-4                  [-1, 256]             512
         LeakyReLU-5                  [-1, 256]               0
            Linear-6                  [-1, 512]         131,584
       BatchNorm1d-7                  [-1, 512]           1,024
         LeakyReLU-8                  [-1, 512]               0
            Linear-9                 [-1, 1024]         525,312
      BatchNorm1d-10                 [-1, 1024]           2,048
        LeakyReLU-11                 [-1, 1024]               0
           Linear-12                  [-1, 784]         803,600
             Tanh-13                  [-1, 784]               0
================================================================
Total params: 1,510,032
Trainable params: 1,510,032
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 5.76
Estimated Total Size (MB): 5.82
----------------------------------------------------------------
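The parameter counts in the summary above can be checked by hand. The sketch below assumes the usual counting rules (a linear layer has `in * out` weights plus `out` biases, a batch-norm layer has one scale and one shift per feature) and 4 bytes per float32 parameter; the helper names are hypothetical.

```python
def linear_params(n_in, n_out):
    # Weight matrix (n_in * n_out) plus one bias per output unit
    return n_in * n_out + n_out


def batchnorm_params(n_features):
    # One learnable scale and one learnable shift per feature
    return 2 * n_features


generator_params = {
    "Linear-1": linear_params(100, 128),
    "Linear-3": linear_params(128, 256),
    "BatchNorm1d-4": batchnorm_params(256),
    "Linear-6": linear_params(256, 512),
    "BatchNorm1d-7": batchnorm_params(512),
    "Linear-9": linear_params(512, 1024),
    "BatchNorm1d-10": batchnorm_params(1024),
    "Linear-12": linear_params(1024, 28 * 28),
}

total_params = sum(generator_params.values())
params_mb = total_params * 4 / 1024 ** 2  # assuming float32 parameters

print(f"Total params: {total_params:,}")       # Total params: 1,510,032
print(f"Params size (MB): {params_mb:.2f}")    # Params size (MB): 5.76
```

The activation layers (LeakyReLU, Tanh) have no learnable parameters, which is why they appear with a count of 0 in the summary.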
Discriminator
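import numpy as np
import torch.nn as nn


class Discriminator(nn.Module):
    """Discriminator: scores how likely an image is to be real."""

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            # A simple multilayer perceptron with a sigmoid output
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        # Flatten: [b, 1, 28, 28] => [b, 784]
        img_flat = img.view(img.size(0), -1)
        # [b, 784] => [b, 1], probability that each image is real
        validity = self.model(img_flat)
        return validity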
Network structure:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                  [-1, 512]         401,920
         LeakyReLU-2                  [-1, 512]               0
            Linear-3                  [-1, 256]         131,328
         LeakyReLU-4                  [-1, 256]               0
            Linear-5                    [-1, 1]             257
           Sigmoid-6                    [-1, 1]               0
================================================================
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 2.04
Estimated Total Size (MB): 2.05
----------------------------------------------------------------
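As a quick cross-check of the discriminator summary, the layer stack can be rebuilt and its parameters counted directly in PyTorch. This is a sketch under assumptions: `img_shape = (1, 28, 28)` is assumed from the MNIST setting, the input batch is random noise rather than real digits, and the 0.5 decision threshold is a hypothetical choice for turning the sigmoid output into a Real / Fake label.

```python
import numpy as np
import torch
import torch.nn as nn

img_shape = (1, 28, 28)  # assumed (channels, height, width) of an MNIST digit

# Same layer stack as the discriminator's self.model
model = nn.Sequential(
    nn.Linear(int(np.prod(img_shape)), 512),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Linear(512, 256),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Linear(256, 1),
    nn.Sigmoid(),
)

# Count learnable parameters and estimate their size as float32
total_params = sum(p.numel() for p in model.parameters())
params_mb = total_params * 4 / 1024 ** 2

# Random values in [-1, 1] stand in for a batch of 8 images
imgs = torch.rand(8, *img_shape) * 2 - 1
validity = model(imgs.view(imgs.size(0), -1))  # [8, 1] scores in (0, 1)
is_real = validity > 0.5                       # hypothetical threshold

print(f"Total params: {total_params:,}")       # Total params: 533,505
print(f"Params size (MB): {params_mb:.2f}")    # Params size (MB): 2.04
```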