Everything Can Be GANed: Generating Handwritten Digits with a GAN (Part 2)
Posted by 我是小白呀
Overview
A GAN (Generative Adversarial Network) pairs two networks: a generator, which synthesizes images from random noise, and a discriminator, which learns to tell real images from generated ones. Trained against each other, the two networks learn the relevant features automatically, with the discriminator's judgments driving the generator's improvement.
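Formally, training is the minimax game from the original GAN paper (Goodfellow et al., 2014); the two BCELoss terms in main.py below implement the two sides of this objective:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_\text{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]

One detail worth noting: the generator step in the code uses the common non-saturating variant. Instead of minimizing \log(1 - D(G(z))), it maximizes \log D(G(z)) by computing BCE against the "real" label, which gives stronger gradients early in training.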
Complete Code
Model
model.py:
import numpy as np
import torch.nn as nn


class Generator(nn.Module):
    """Generator: maps a latent vector z to a flattened image."""

    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            """
            Linear -> (BatchNorm) -> LeakyReLU block
            :param in_feat: input feature dimension
            :param out_feat: output feature dimension
            :param normalize: whether to apply batch normalization
            :return: list of layers
            """
            layers = [nn.Linear(in_feat, out_feat)]
            # Batch normalization. Note: the second positional argument of
            # nn.BatchNorm1d is eps, so this sets eps=0.8; if a momentum of
            # 0.8 was intended, it should be written momentum=0.8.
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            # Activation
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # [b, 100] => [b, 128]
            *block(latent_dim, 128, normalize=False),
            # [b, 128] => [b, 256]
            *block(128, 256),
            # [b, 256] => [b, 512]
            *block(256, 512),
            # [b, 512] => [b, 1024]
            *block(512, 1024),
            # [b, 1024] => [b, 28 * 28 * 1] => [b, 784]
            nn.Linear(1024, int(np.prod(img_shape))),
            # Tanh squashes outputs to [-1, 1], matching the normalized data
            nn.Tanh()
        )

    def forward(self, z, img_shape):
        # img_shape is passed in at call time (see main.py) rather than
        # read from __init__
        # [b, 100] => [b, 784]
        img = self.model(z)
        # [b, 784] => [b, 1, 28, 28]
        img = img.view(img.size(0), *img_shape)
        # Return the generated images
        return img


class Discriminator(nn.Module):
    """Discriminator: scores an image as real (1) or fake (0)."""

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            # [b, 1, 28, 28] => [b, 784] => [b, 512]
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            # Sigmoid outputs a probability in (0, 1), as required by BCELoss
            nn.Sigmoid(),
        )

    def forward(self, img):
        # Flatten: [b, 1, 28, 28] => [b, 784]
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity
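As a quick sanity check (a minimal sketch added here, not part of the original post), the two networks can be instantiated on dummy data to confirm the tensor shapes annotated in the comments above:

import torch
from model import Generator, Discriminator

latent_dim = 100
img_shape = (1, 28, 28)  # (channels, height, width) for MNIST

generator = Generator(latent_dim=latent_dim, img_shape=img_shape)
discriminator = Discriminator(img_shape=img_shape)

z = torch.randn(4, latent_dim)       # a batch of 4 latent vectors
imgs = generator(z, img_shape)       # => torch.Size([4, 1, 28, 28])
validity = discriminator(imgs)       # => torch.Size([4, 1])
print(imgs.shape, validity.shape)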
Main Program
main.py:
import argparse
import time
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torchvision.transforms as transforms
from torchvision.utils import save_image

from model import Generator, Discriminator


def get_data(img_size, batch_size):
    """Build the MNIST dataloader; Normalize([0.5], [0.5]) maps pixels to [-1, 1] to match the generator's Tanh output."""
    dataloader = DataLoader(
        datasets.MNIST(
            "./data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
    )
    return dataloader


# ---------- Training ----------
def train_model(n_epochs, dataloader, image_shape, image_path):
    """
    Train the GAN. Uses the module-level generator, discriminator, optimizers,
    adversarial_loss, Tensor and opt defined under __main__ below.
    :param n_epochs: number of training epochs
    :param dataloader: training data
    :param image_shape: image shape (channels, height, width)
    :param image_path: directory name for saved sample grids
    """
    for epoch in range(n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            # Adversarial ground truths
            # [b, 1] tensor of ones: label for real images
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
            # [b, 1] tensor of zeros: label for fake images
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            # Sample noise as generator input: [b, 100]
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
            # Generate a batch of images: [b, 1, 28, 28]
            gen_imgs = generator(z, image_shape)
            # Loss measures the generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            # Measure the discriminator's ability to tell real from generated
            # samples; detach() stops the discriminator loss from
            # backpropagating into the generator
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

            # Periodically save a 5x5 grid of generated samples
            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25],
                           "images/{}/{}.png".format(image_path, batches_done), nrow=5,
                           normalize=True)


def option():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=16384, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
    opt = parser.parse_args()
    return opt


if __name__ == "__main__":
    # Timestamped directory for saved sample images
    image_path = time.strftime("%Y_%m_%d_%H_%M_%S")
    os.makedirs("images", exist_ok=True)
    os.makedirs("images/{}".format(image_path))

    # Hyperparameters
    opt = option()
    print(opt)
    img_shape = (opt.channels, opt.img_size, opt.img_size)

    cuda = True if torch.cuda.is_available() else False

    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator(latent_dim=opt.latent_dim, img_shape=img_shape)
    discriminator = Discriminator(img_shape=img_shape)
    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    dataloader = get_data(
        img_size=opt.img_size,
        batch_size=opt.batch_size
    )
    train_model(
        n_epochs=opt.n_epochs,
        dataloader=dataloader,
        image_shape=img_shape,
        image_path=image_path
    )
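After training, the generator alone can synthesize new digits. The sketch below is an addition, not part of the original post; it assumes the trained generator, Tensor, opt, img_shape and image_path from the __main__ block above are still in scope, and final.png is just a hypothetical file name:

# Switch to eval mode so BatchNorm uses its running statistics
generator.eval()
with torch.no_grad():
    # Sample 25 latent vectors and decode them into a 5x5 grid
    z = Tensor(np.random.normal(0, 1, (25, opt.latent_dim)))
    samples = generator(z, img_shape)  # [25, 1, 28, 28]
    save_image(samples, "images/{}/final.png".format(image_path),
               nrow=5, normalize=True)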
Output
The log below comes from a run launched with command-line flags overriding the defaults above (n_epochs=300, batch_size=30000, n_cpu=16, sample_interval=20). With a batch size of 30000, the 60,000-image MNIST training set yields just 2 batches per epoch, hence "Batch 0/2" and "Batch 1/2" on every line. Note that the two losses oscillate around each other rather than steadily decreasing; this tug-of-war is the expected signature of adversarial training.
Namespace(b1=0.5, b2=0.999, batch_size=30000, channels=1, img_size=28, latent_dim=100, lr=0.0002, n_cpu=16, n_epochs=300, sample_interval=20)
[Epoch 0/300] [Batch 0/2] [D loss: 0.713072] [G loss: 0.732759]
[Epoch 0/300] [Batch 1/2] [D loss: 0.625757] [G loss: 0.729401]
[Epoch 1/300] [Batch 0/2] [D loss: 0.557747] [G loss: 0.726514]
[Epoch 1/300] [Batch 1/2] [D loss: 0.501802] [G loss: 0.723442]
[Epoch 2/300] [Batch 0/2] [D loss: 0.456023] [G loss: 0.719633]
[Epoch 2/300] [Batch 1/2] [D loss: 0.421082] [G loss: 0.714690]
[Epoch 3/300] [Batch 0/2] [D loss: 0.397120] [G loss: 0.708184]
[Epoch 3/300] [Batch 1/2] [D loss: 0.382698] [G loss: 0.699813]
[Epoch 4/300] [Batch 0/2] [D loss: 0.376455] [G loss: 0.689178]
[Epoch 4/300] [Batch 1/2] [D loss: 0.375903] [G loss: 0.675870]
[Epoch 5/300] [Batch 0/2] [D loss: 0.379960] [G loss: 0.660027]
[Epoch 5/300] [Batch 1/2] [D loss: 0.387069] [G loss: 0.642543]
[Epoch 6/300] [Batch 0/2] [D loss: 0.396221] [G loss: 0.624336]
[Epoch 6/300] [Batch 1/2] [D loss: 0.406274] [G loss: 0.606929]
[Epoch 7/300] [Batch 0/2] [D loss: 0.416247] [G loss: 0.591856]
[Epoch 7/300] [Batch 1/2] [D loss: 0.424656] [G loss: 0.581296]
[Epoch 8/300] [Batch 0/2] [D loss: 0.431091] [G loss: 0.575895]
[Epoch 8/300] [Batch 1/2] [D loss: 0.435463] [G loss: 0.576295]
[Epoch 9/300] [Batch 0/2] [D loss: 0.438092] [G loss: 0.582505]
[Epoch 9/300] [Batch 1/2] [D loss: 0.439382] [G loss: 0.593778]
[Epoch 10/300] [Batch 0/2] [D loss: 0.440484] [G loss: 0.607302]
[Epoch 10/300] [Batch 1/2] [D loss: 0.441629] [G loss: 0.619709]
[Epoch 11/300] [Batch 0/2] [D loss: 0.444755] [G loss: 0.625674]
[Epoch 11/300] [Batch 1/2] [D loss: 0.451140] [G loss: 0.625349]
[Epoch 12/300] [Batch 0/2] [D loss: 0.461600] [G loss: 0.623004]
[Epoch 12/300] [Batch 1/2] [D loss: 0.475907] [G loss: 0.616745]
[Epoch 13/300] [Batch 0/2] [D loss: 0.492632] [G loss: 0.603465]
[Epoch 13/300] [Batch 1/2] [D loss: 0.509568] [G loss: 0.589097]
[Epoch 14/300] [Batch 0/2] [D loss: 0.524420] [G loss: 0.579904]
[Epoch 14/300] [Batch 1/2] [D loss: 0.532494] [G loss: 0.580591]
[Epoch 15/300] [Batch 0/2] [D loss: 0.533520] [G loss: 0.582400]
[Epoch 15/300] [Batch 1/2] [D loss: 0.528302] [G loss: 0.605238]
[Epoch 16/300] [Batch 0/2] [D loss: 0.520255] [G loss: 0.617953]
[Epoch 16/300] [Batch 1/2] [D loss: 0.512036] [G loss: 0.637511]
[Epoch 17/300] [Batch 0/2] [D loss: 0.506406] [G loss: 0.650202]
[Epoch 17/300] [Batch 1/2] [D loss: 0.507003] [G loss: 0.660099]
[Epoch 18/300] [Batch 0/2] [D loss: 0.515254] [G loss: 0.652070]
[Epoch 18/300] [Batch 1/2] [D loss: 0.531098] [G loss: 0.644829]
[Epoch 19/300] [Batch 0/2] [D loss: 0.546756] [G loss: 0.635194]
[Epoch 19/300] [Batch 1/2] [D loss: 0.557034] [G loss: 0.625947]
[Epoch 20/300] [Batch 0/2] [D loss: 0.562781] [G loss: 0.631091]
... ...
[Epoch 280/300] [Batch 1/2] [D loss: 0.593827] [G loss: 0.486869]
[Epoch 281/300] [Batch 0/2] [D loss: 0.488985] [G loss: 1.372436]
[Epoch 281/300] [Batch 1/2] [D loss: 0.478138] [G loss: 0.782779]
[Epoch 282/300] [Batch 0/2] [D loss: 0.445211] [G loss: 1.119433]
[Epoch 282/300] [Batch 1/2] [D loss: 0.450222] [G loss: 0.956143]
[Epoch 283/300] [Batch 0/2] [D loss: 0.459510] [G loss: 1.030325]
[Epoch 283/300] [Batch 1/2] [D loss: 0.481445] [G loss: 0.955175]
[Epoch 284/300] [Batch 0/2] [D loss: 0.494930] [G loss: 0.967074]
[Epoch 284/300] [Batch 1/2] [D loss: 0.513425] [G loss: 0.888255]
[Epoch 285/300] [Batch 0/2] [D loss: 0.523844] [G loss: 0.933630]
[Epoch 285/300] [Batch 1/2] [D loss: 0.528002] [G loss: 0.829575]
[Epoch 286/300] [Batch 0/2] [D loss: 0.510690] [G loss: 1.025084]
[Epoch 286/300] [Batch 1/2] [D loss: 0.511791] [G loss: 0.752756]
[Epoch 287/300] [Batch 0/2] [D loss: 0.491313] [G loss: 1.258785]
[Epoch 287/300] [Batch 1/2] [D loss: 0.539705] [G loss: 0.594626]
[Epoch 288/300] [Batch 0/2] [D loss: 0.523798] [G loss: 1.575793]
[Epoch 288/300] [Batch 1/2] [D loss: 0.595133] [G loss: 0.471295]
[Epoch 289/300] [Batch 0/2] [D loss: 0.477448] [G loss: 1.566631]
[Epoch 289/300] [Batch 1/2] [D loss: 0.477544] [G loss: 0.706210]
[Epoch 290/300] [Batch 0/2] [D loss: 0.429852] [G loss: 1.241274]
[Epoch 290/300] [Batch 1/2] [D loss: 0.437449] [G loss: 0.915428]
[Epoch 291/300] [Batch 0/2] [D loss: 0.434945] [G loss: 1.074615]
[Epoch 291/300] [Batch 1/2] [D loss: 0.449143] [G loss: 0.958715]
[Epoch 292/300] [Batch 0/2] [D loss: 0.454942] [G loss: 1.042261]
[Epoch 292/300] [Batch 1/2] [D loss: 0.469878] [G loss: 0.926095]
[Epoch 293/300] [Batch 0/2] [D loss: 0.466286] [G loss: 1.058231]
[Epoch 293/300] [Batch 1/2] [D loss: 0.471976] [G loss: 0.851719]
[Epoch 294/300] [Batch 0/2] [D loss: 0.459671] [G loss: 1.190166]
[Epoch 294/300] [Batch 1/2] [D loss: 0.488745] [G loss: 0.692813]
[Epoch 295/300] [Batch 0/2] [D loss: 0.478135] [G loss: 1.635507]
[Epoch 295/300] [Batch 1/2] [D loss: 0.606305] [G loss: 0.428279]
[Epoch 296/300] [Batch 0/2] [D loss: 0.525440] [G loss: 2.067884]
[Epoch 296/300] [Batch 1/2] [D loss: 0.568282] [G loss: 0.477078]
[Epoch 297/300] [Batch 0/2] [D loss: 0.402702] [G loss: 1.508152]
[Epoch 297/300] [Batch 1/2] [D loss: 0.412127] [G loss: 0.929811]
[Epoch 298/300] [Batch 0/2] [D loss: 0.439405] [G loss: 0.930188]
[Epoch 298/300] [Batch 1/2] [D loss: 0.471616] [G loss: 1.025511]
[Epoch 299/300] [Batch 0/2] [D loss: 0.524946] [G loss: 0.700306]
[Epoch 299/300] [Batch 1/2] [D loss: 0.537755] [G loss: 1.130932]
Generated Images
(Sample grids, saved every sample_interval batches; images omitted here: 0.png, 100.png, 200.png, 300.png, 400.png, 500.png, 580.png.)