[万物皆可 GAN] Generating Handwritten Digits with a Generative Adversarial Network, Part 2

Posted by 我是小白呀


Overview

GAN stands for Generative Adversarial Network. A GAN pairs two networks: a generator (Generator), which maps random noise to synthetic samples, and a discriminator (Discriminator), which tries to tell real samples from generated ones. Trained against each other, the two networks learn features automatically: the discriminator's judgments drive the generator to produce increasingly realistic outputs.

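Concretely, both sides of this adversarial game are trained with binary cross-entropy in this tutorial: the discriminator is pushed toward D(x) → 1 on real images and D(G(z)) → 0 on fakes, while the generator is pushed to make D(G(z)) → 1 (the non-saturating form). A minimal numeric sketch with made-up probabilities:

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()

# Illustrative discriminator outputs (probabilities in (0, 1)); values are made up.
d_real = torch.tensor([[0.9], [0.8]])   # D(x) on real samples
d_fake = torch.tensor([[0.2], [0.1]])   # D(G(z)) on generated samples

real_labels = torch.ones_like(d_real)
fake_labels = torch.zeros_like(d_fake)

# Discriminator objective: D(x) -> 1 and D(G(z)) -> 0
d_loss = (bce(d_real, real_labels) + bce(d_fake, fake_labels)) / 2

# Generator objective (non-saturating form): D(G(z)) -> 1
g_loss = bce(d_fake, real_labels)

print(d_loss.item(), g_loss.item())
```

Because the fakes here are easy to spot, the discriminator loss is small while the generator loss is large; training drives both toward an equilibrium.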

Complete Code

Model

model.py:

import numpy as np
import torch.nn as nn


class Generator(nn.Module):
    """生成器"""

    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            """
            Build one fully connected block.
            :param in_feat: input feature dimension
            :param out_feat: output feature dimension
            :param normalize: whether to apply batch normalization
            :return: list of layers
            """
            layers = [nn.Linear(in_feat, out_feat)]

            # Batch normalization (note: 0.8 here is the positional eps argument)
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))

            # Activation
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # [b, 100] => [b, 128]
            *block(latent_dim, 128, normalize=False),
            # [b, 128] => [b, 256]
            *block(128, 256),
            # [b, 256] => [b, 512]
            *block(256, 512),
            # [b, 512] => [b, 1024]
            *block(512, 1024),
            # [b, 1024] => [b, 28 * 28 * 1] => [b, 784]
            nn.Linear(1024, int(np.prod(img_shape))),
            # Tanh maps outputs to [-1, 1], matching the normalized inputs
            nn.Tanh()
        )

    def forward(self, z, img_shape):
        # [b, 100] => [b, 784]
        img = self.model(z)
        # [b, 784] => [b, 1, 28, 28]
        img = img.view(img.size(0), *img_shape)

        # Return the generated images
        return img


class Discriminator(nn.Module):
    """判断器"""

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            # [b, 1, 28, 28] => [b, 784] => [b, 512]
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        # Flatten: [b, 1, 28, 28] => [b, 784]
        img_flat = img.view(img.size(0), -1)

        validity = self.model(img_flat)

        return validity

Main Function

main.py:

import argparse
import time
import os
import numpy as np

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torchvision.transforms as transforms
from torchvision.utils import save_image

from model import Generator, Discriminator


def get_data(img_size, batch_size):
    """获取数据"""

    dataloader = DataLoader(
        datasets.MNIST(
            "./data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    return dataloader


# ---------- Training ----------
def train_model(n_epochs, dataloader, image_shape, image_path):
    """
    训练模型
    :param n_epochs: 迭代次数
    :param dataloader: 数据集
    :param image_shape: 图片形状
    :param image_path: 图片保存路径
    :return:
    """

    # Iterate over the dataset n_epochs times
    for epoch in range(n_epochs):
        for i, (imgs, _) in enumerate(dataloader):

            # Adversarial ground truths
            # [b, 1] tensor of ones: labels for real images
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
            # [b, 1] tensor of zeros: labels for fake images
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input [b, 100]
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

            # Generate images [b, 1, 28, 28]
            gen_imgs = generator(z, image_shape)

            # Loss measures generator's ability to fool the discriminator
            # Compute the generator loss
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25],
                           "images/{}/{}.png".format(image_path, batches_done), nrow=5,
                           normalize=True)


def option():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=16384, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
    opt = parser.parse_args()
    return opt


if __name__ == "__main__":
    # Directory name for saved sample images
    image_path = time.strftime("%Y_%m_%d_%H_%M_%S")
    os.makedirs("images", exist_ok=True)
    os.makedirs("images/{}".format(image_path))

    # Hyperparameters
    opt = option()
    print(opt)

    img_shape = (opt.channels, opt.img_size, opt.img_size)

    cuda = True if torch.cuda.is_available() else False

    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize the generator and discriminator
    generator = Generator(latent_dim=opt.latent_dim, img_shape=img_shape)
    discriminator = Discriminator(img_shape=img_shape)

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    dataloader = get_data(
        img_size=opt.img_size,
        batch_size=opt.batch_size
    )

    train_model(
        n_epochs=opt.n_epochs,
        dataloader=dataloader,
        image_shape=img_shape,
        image_path=image_path
    )
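The loop in `train_model` can be exercised end to end without downloading MNIST. The sketch below runs one generator update followed by one discriminator update on random tensors, mirroring the update order above: G first, then D with `detach()` so the discriminator step does not backpropagate into the generator. All layer sizes here are made up for brevity, and plain tensors replace `Variable` (which is a no-op in modern PyTorch):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, img_dim, batch = 8, 16, 4

# Tiny stand-ins for Generator / Discriminator (hypothetical sizes)
G = nn.Sequential(nn.Linear(latent_dim, 32), nn.LeakyReLU(0.2),
                  nn.Linear(32, img_dim), nn.Tanh())
D = nn.Sequential(nn.Linear(img_dim, 32), nn.LeakyReLU(0.2),
                  nn.Linear(32, 1), nn.Sigmoid())

bce = nn.BCELoss()
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

real = torch.rand(batch, img_dim) * 2 - 1   # stand-in for a real batch in [-1, 1]
valid = torch.ones(batch, 1)
fake = torch.zeros(batch, 1)

# --- Generator step (same order as train_model) ---
opt_G.zero_grad()
z = torch.randn(batch, latent_dim)
gen_imgs = G(z)
g_loss = bce(D(gen_imgs), valid)
g_loss.backward()
opt_G.step()

# --- Discriminator step: detach() blocks gradients from reaching G ---
opt_D.zero_grad()
d_loss = (bce(D(real), valid) + bce(D(gen_imgs.detach()), fake)) / 2
d_loss.backward()
opt_D.step()

print(g_loss.item(), d_loss.item())
```

Without the `detach()`, the discriminator update would also push the generator's weights in the wrong direction (toward being detected).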

Output

Namespace(b1=0.5, b2=0.999, batch_size=30000, channels=1, img_size=28, latent_dim=100, lr=0.0002, n_cpu=16, n_epochs=300, sample_interval=20)
[Epoch 0/300] [Batch 0/2] [D loss: 0.713072] [G loss: 0.732759]
[Epoch 0/300] [Batch 1/2] [D loss: 0.625757] [G loss: 0.729401]
[Epoch 1/300] [Batch 0/2] [D loss: 0.557747] [G loss: 0.726514]
[Epoch 1/300] [Batch 1/2] [D loss: 0.501802] [G loss: 0.723442]
[Epoch 2/300] [Batch 0/2] [D loss: 0.456023] [G loss: 0.719633]
[Epoch 2/300] [Batch 1/2] [D loss: 0.421082] [G loss: 0.714690]
[Epoch 3/300] [Batch 0/2] [D loss: 0.397120] [G loss: 0.708184]
[Epoch 3/300] [Batch 1/2] [D loss: 0.382698] [G loss: 0.699813]
[Epoch 4/300] [Batch 0/2] [D loss: 0.376455] [G loss: 0.689178]
[Epoch 4/300] [Batch 1/2] [D loss: 0.375903] [G loss: 0.675870]
[Epoch 5/300] [Batch 0/2] [D loss: 0.379960] [G loss: 0.660027]
[Epoch 5/300] [Batch 1/2] [D loss: 0.387069] [G loss: 0.642543]
[Epoch 6/300] [Batch 0/2] [D loss: 0.396221] [G loss: 0.624336]
[Epoch 6/300] [Batch 1/2] [D loss: 0.406274] [G loss: 0.606929]
[Epoch 7/300] [Batch 0/2] [D loss: 0.416247] [G loss: 0.591856]
[Epoch 7/300] [Batch 1/2] [D loss: 0.424656] [G loss: 0.581296]
[Epoch 8/300] [Batch 0/2] [D loss: 0.431091] [G loss: 0.575895]
[Epoch 8/300] [Batch 1/2] [D loss: 0.435463] [G loss: 0.576295]
[Epoch 9/300] [Batch 0/2] [D loss: 0.438092] [G loss: 0.582505]
[Epoch 9/300] [Batch 1/2] [D loss: 0.439382] [G loss: 0.593778]
[Epoch 10/300] [Batch 0/2] [D loss: 0.440484] [G loss: 0.607302]
[Epoch 10/300] [Batch 1/2] [D loss: 0.441629] [G loss: 0.619709]
[Epoch 11/300] [Batch 0/2] [D loss: 0.444755] [G loss: 0.625674]
[Epoch 11/300] [Batch 1/2] [D loss: 0.451140] [G loss: 0.625349]
[Epoch 12/300] [Batch 0/2] [D loss: 0.461600] [G loss: 0.623004]
[Epoch 12/300] [Batch 1/2] [D loss: 0.475907] [G loss: 0.616745]
[Epoch 13/300] [Batch 0/2] [D loss: 0.492632] [G loss: 0.603465]
[Epoch 13/300] [Batch 1/2] [D loss: 0.509568] [G loss: 0.589097]
[Epoch 14/300] [Batch 0/2] [D loss: 0.524420] [G loss: 0.579904]
[Epoch 14/300] [Batch 1/2] [D loss: 0.532494] [G loss: 0.580591]
[Epoch 15/300] [Batch 0/2] [D loss: 0.533520] [G loss: 0.582400]
[Epoch 15/300] [Batch 1/2] [D loss: 0.528302] [G loss: 0.605238]
[Epoch 16/300] [Batch 0/2] [D loss: 0.520255] [G loss: 0.617953]
[Epoch 16/300] [Batch 1/2] [D loss: 0.512036] [G loss: 0.637511]
[Epoch 17/300] [Batch 0/2] [D loss: 0.506406] [G loss: 0.650202]
[Epoch 17/300] [Batch 1/2] [D loss: 0.507003] [G loss: 0.660099]
[Epoch 18/300] [Batch 0/2] [D loss: 0.515254] [G loss: 0.652070]
[Epoch 18/300] [Batch 1/2] [D loss: 0.531098] [G loss: 0.644829]
[Epoch 19/300] [Batch 0/2] [D loss: 0.546756] [G loss: 0.635194]
[Epoch 19/300] [Batch 1/2] [D loss: 0.557034] [G loss: 0.625947]
[Epoch 20/300] [Batch 0/2] [D loss: 0.562781] [G loss: 0.631091]


... ...

[Epoch 280/300] [Batch 1/2] [D loss: 0.593827] [G loss: 0.486869]
[Epoch 281/300] [Batch 0/2] [D loss: 0.488985] [G loss: 1.372436]
[Epoch 281/300] [Batch 1/2] [D loss: 0.478138] [G loss: 0.782779]
[Epoch 282/300] [Batch 0/2] [D loss: 0.445211] [G loss: 1.119433]
[Epoch 282/300] [Batch 1/2] [D loss: 0.450222] [G loss: 0.956143]
[Epoch 283/300] [Batch 0/2] [D loss: 0.459510] [G loss: 1.030325]
[Epoch 283/300] [Batch 1/2] [D loss: 0.481445] [G loss: 0.955175]
[Epoch 284/300] [Batch 0/2] [D loss: 0.494930] [G loss: 0.967074]
[Epoch 284/300] [Batch 1/2] [D loss: 0.513425] [G loss: 0.888255]
[Epoch 285/300] [Batch 0/2] [D loss: 0.523844] [G loss: 0.933630]
[Epoch 285/300] [Batch 1/2] [D loss: 0.528002] [G loss: 0.829575]
[Epoch 286/300] [Batch 0/2] [D loss: 0.510690] [G loss: 1.025084]
[Epoch 286/300] [Batch 1/2] [D loss: 0.511791] [G loss: 0.752756]
[Epoch 287/300] [Batch 0/2] [D loss: 0.491313] [G loss: 1.258785]
[Epoch 287/300] [Batch 1/2] [D loss: 0.539705] [G loss: 0.594626]
[Epoch 288/300] [Batch 0/2] [D loss: 0.523798] [G loss: 1.575793]
[Epoch 288/300] [Batch 1/2] [D loss: 0.595133] [G loss: 0.471295]
[Epoch 289/300] [Batch 0/2] [D loss: 0.477448] [G loss: 1.566631]
[Epoch 289/300] [Batch 1/2] [D loss: 0.477544] [G loss: 0.706210]
[Epoch 290/300] [Batch 0/2] [D loss: 0.429852] [G loss: 1.241274]
[Epoch 290/300] [Batch 1/2] [D loss: 0.437449] [G loss: 0.915428]
[Epoch 291/300] [Batch 0/2] [D loss: 0.434945] [G loss: 1.074615]
[Epoch 291/300] [Batch 1/2] [D loss: 0.449143] [G loss: 0.958715]
[Epoch 292/300] [Batch 0/2] [D loss: 0.454942] [G loss: 1.042261]
[Epoch 292/300] [Batch 1/2] [D loss: 0.469878] [G loss: 0.926095]
[Epoch 293/300] [Batch 0/2] [D loss: 0.466286] [G loss: 1.058231]
[Epoch 293/300] [Batch 1/2] [D loss: 0.471976] [G loss: 0.851719]
[Epoch 294/300] [Batch 0/2] [D loss: 0.459671] [G loss: 1.190166]
[Epoch 294/300] [Batch 1/2] [D loss: 0.488745] [G loss: 0.692813]
[Epoch 295/300] [Batch 0/2] [D loss: 0.478135] [G loss: 1.635507]
[Epoch 295/300] [Batch 1/2] [D loss: 0.606305] [G loss: 0.428279]
[Epoch 296/300] [Batch 0/2] [D loss: 0.525440] [G loss: 2.067884]
[Epoch 296/300] [Batch 1/2] [D loss: 0.568282] [G loss: 0.477078]
[Epoch 297/300] [Batch 0/2] [D loss: 0.402702] [G loss: 1.508152]
[Epoch 297/300] [Batch 1/2] [D loss: 0.412127] [G loss: 0.929811]
[Epoch 298/300] [Batch 0/2] [D loss: 0.439405] [G loss: 0.930188]
[Epoch 298/300] [Batch 1/2] [D loss: 0.471616] [G loss: 1.025511]
[Epoch 299/300] [Batch 0/2] [D loss: 0.524946] [G loss: 0.700306]
[Epoch 299/300] [Batch 1/2] [D loss: 0.537755] [G loss: 1.130932]
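Note that the D and G losses above oscillate rather than converge monotonically, which is typical of GAN training. The log lines follow a fixed format, so they are easy to parse into series for plotting loss curves. A small sketch using a regex on one of the lines shown:

```python
import re

# Pattern matching the log format printed by train_model
LOG_RE = re.compile(
    r"\[Epoch (\d+)/(\d+)\] \[Batch (\d+)/(\d+)\] "
    r"\[D loss: ([\d.]+)\] \[G loss: ([\d.]+)\]"
)

line = "[Epoch 299/300] [Batch 1/2] [D loss: 0.537755] [G loss: 1.130932]"
m = LOG_RE.match(line)
epoch, batch = int(m.group(1)), int(m.group(3))
d_loss, g_loss = float(m.group(5)), float(m.group(6))
print(epoch, batch, d_loss, g_loss)
```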

Generated Images

0.png: [sample image]
100.png: [sample image]
200.png: [sample image]
300.png: [sample image]
400.png: [sample image]
500.png: [sample image]
580.png: [sample image]
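The filenames are the value of `batches_done` at save time. With this run's settings (60000 MNIST images at batch_size 30000 gives 2 batches per epoch, and sample_interval is 20), the saved indices work out to multiples of 20, ending at 580:

```python
n_epochs, batches_per_epoch, sample_interval = 300, 2, 20

# Reproduce the batches_done values at which train_model saves an image grid
saved = [e * batches_per_epoch + i
         for e in range(n_epochs)
         for i in range(batches_per_epoch)
         if (e * batches_per_epoch + i) % sample_interval == 0]

print(saved[:5], saved[-1], len(saved))
```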
