搭建简单GAN生成MNIST手写体

Posted Paul-Huang

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了搭建简单GAN生成MNIST手写体相关的知识,希望对你有一定的参考价值。

Keras搭建GAN生成MNIST手写体

GAN简介

  1. 生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。

  2. 在GAN模型中,一般存在两个模块:
    分别是生成模型(Generative Model)和判别模型(Discriminative Model);二者的互相博弈与学习将会产生相当好的输出。

    原始 GAN 理论中,并不要求生成模型和判别模型都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为生成模型和判别模型 。

  3. GAN的训练方法
    其实简单来讲,一般情况下,GAN就是创建两个神经网络,一个是生成模型,一个是判别模型。

    • 生成模型 输 入 \\color{red}输入 是一行正态分布随机数, 输 出 \\color{red}输出 可以被认为是一张图片(或者其它需要被判定真伪的东西)。
    • 判别模型 输 入 \\color{red}输入 是一张图片(或者其它需要被判定真伪的东西), 输 出 \\color{red}输出 是输入进来的图片是否是真实的(0或者1)。

MNIST数据搭建网络

1. Generator

  • 输入一行正态分布随机数,生成mnist手写体图片,因此它的输入是一个长度为N的一维的向量;

  • 输出一个28,28,1维的图片。

    def build_generator(self):
        # --------------------------------- #
        #   生成器,输入一串随机数字
        # --------------------------------- #
        model = Sequential(name='generator')
    
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.1))
        model.add(BatchNormalization(momentum=0.8))
    
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
    
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
    
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))
    	
        noise = Input(shape=(self.latent_dim,))
        img = model(noise)
    
        return Model(noise, img)
    

2、Discriminator

判别模型的目的是根据输入的图片判断出真伪。

  • 输入一个28,28,1维的图片;
  • 输出是0到1之间的数,1代表判断这个图片是真的,0代表判断这个图片是假的。
def build_discriminator(self):
    # ----------------------------------- #
    #   评价器,对输入进来的图片进行评价
    # ----------------------------------- #
    model = Sequential(name='discriminator')
    # 输入一张图片
    model.add(Flatten(input_shape=self.img_shape))
    model.add(Dense(64))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(128))
    model.add(LeakyReLU(alpha=0.2))
    # 判断真伪
    model.add(Dense(1, activation='sigmoid'))
	# 判断输入图片的类型(0/1)
    img = Input(shape=self.img_shape)
    validity = model(img)

    return Model(img, validity)

3. 初始化GAN模型

  1. 创建判别器生成器
  2. 训练生成器(generate)的时候不训练判别器(discriminator)
  3. 对生成的假图片进行预测
def __init__(self):
        # --------------------------------- #
        #   行28,列28,也就是mnist的shape
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # 28,28,1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # adam优化器
        optimizer = Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # 在训练generate的时候不训练discriminator
        self.discriminator.trainable = False
        # 对生成的假图片进行预测
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

4. 训练环节

GAN的训练分为如下几个步骤:
1、随机选取batch_size个真实的图片。
2、随机生成batch_size个N维向量,传入到Generator中生成batch_size个虚假图片。
3、将真实图片和虚假图片当作训练集传入到Discriminator中进行训练(真实图片的label为1,虚假图片的label为0)。
4、将虚假图片的Discriminator预测结果与1的对比作为loss对Generator进行训练(与1对比的意思是,如果Discriminator将虚假图片判断为1,说明这个生成的图片很“真实”)。

# gan.train(epochs=60000, batch_size=256, sample_interval=200)
 def train(self, epochs, batch_size=128, sample_interval=50):
        # 获得数据
        (X_train, _), (_, _) = mnist.load_data()

        # 进行标准化
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # 创建标签,真实图片的label为1,虚假图片的label为0
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # --------------------------- #
            #   随机选取batch_size个图片
            #   对discriminator进行训练
            # --------------------------- #
            # 随机选取batch_size个真实的图片
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
			# 随机生成batch_size个N维向量,
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
			# 传入到Generator中生成batch_size个虚假图片
            gen_imgs = self.generator.predict(noise)
			# 将真实图片和虚假图片当作训练集传入到Discriminator中进行训练
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # --------------------------- #
            #  训练generator
            # --------------------------- #
            # 将虚假图片的Discriminator预测结果与1的对比,作为loss对Generator进行训练
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

整体代码

__author__ = 'HQR'
from __future__ import print_function, division

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys
import os
import numpy as np

class GAN():
    def __init__(self):
        # --------------------------------- #
        #   行28,列28,也就是mnist的shape
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # 28,28,1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # adam优化器
        optimizer = Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # 在训练generate的时候不训练discriminator
        self.discriminator.trainable = False
        # 对生成的假图片进行预测
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


    def build_generator(self):
        # --------------------------------- #
        #   生成器,输入一串随机数字
        # --------------------------------- #
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        # ----------------------------------- #
        #   评价器,对输入进来的图片进行评价
        # ----------------------------------- #
        model = Sequential()
        # 输入一张图片
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(64))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(128))
        model.add(LeakyReLU(alpha=0.2))
        # 判断真伪
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        # 获得数据
        (X_train, _), (_, _) = mnist.load_data()

        # 进行标准化
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # 创建标签
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
			# --------------------------- #
            #   随机选取batch_size个图片
            #   对discriminator进行训练
            # --------------------------- #
            # 随机选取batch_size个真实的图片
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
			# 随机生成batch_size个N维向量,
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
			# 传入到Generator中生成batch_size个虚假图片
            gen_imgs = self.generator.predict(noise)
			# 将真实图片和虚假图片当作训练集传入到Discriminator中进行训练
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # --------------------------- #
            #  训练generator
            # --------------------------- #
            # 将虚假图片的Discriminator预测结果与1的对比,作为loss对Generator进行训练
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=60000, batch_size=256, sample_interval=200)

30000的结果:

GAN存在问题

参考参考文献3。

参考:

  1. 好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体
  2. GAN网络生成手写体数字图片
  3. 生成对抗网络GAN详细推导和训练注意事项

以上是关于搭建简单GAN生成MNIST手写体的主要内容,如果未能解决你的问题,请参考以下文章

第一节2:GAN经典案例之MNIST手写数字生成

gan如何做图像增强

手写数字识别——基于全连接层和MNIST数据集

不要怂,就是GAN (生成式对抗网络)

Tensorflow - Tutorial : GAN生成图片

Tensorflow - Tutorial : GAN生成图片