Generating Your Own Handwritten Digit Dataset with ACGAN

Posted by __不想写代码__



Preface

Since I may need GAN networks for some data augmentation work, I reproduced a GAN here, and it turned out to be quite fun.

一、What is a GAN?

A GAN (Generative Adversarial Network) is used to generate realistic-looking data that does not actually exist. Typical applications include:
1. Style transfer: the legendary "AI painter"
2. Image super-resolution: making images sharper
3. Generating realistic data that does not exist: face generation, etc.

Depending on whether labels are used during training, GANs can be divided into unsupervised and semi-supervised networks. A GAN consists of two parts: the Generator (G in the figure) and the Discriminator (D in the figure).
Randomly generated noise passes through the generator to produce the data we want; the generated data and real data are then fed together into the discriminator. If the discriminator correctly identifies the input as generated data, the generator is trained further; if the discriminator mistakes the generated data for real data, the discriminator is trained. The two sides play against each other until the generator can reliably fool the discriminator, at which point the generator can be used to produce the data we want.
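
Formally, this game is the min-max objective from the original GAN paper (restated here for reference; $x$ is a real sample and $z$ is the input noise):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$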

Based on prior experience, the generator generally uses ReLU activations, while the discriminator generally uses LeakyReLU.

二、ACGAN

1. ACGAN network structure

Since ACGAN is a GAN with labels, it should be able to generate the specific data we want if trained properly. Let's look at its network structure:

In the figure, the label C and the noise Z fed into the generator are randomly generated (the noise is generally drawn from a normal distribution). The fake data produced by the generator is fed into the discriminator together with real data. The real data's label is compared against the label output by the discriminator to compute a loss, while the other output only needs to judge whether the input is real or fake.
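
These two outputs correspond to the two loss terms in the ACGAN paper: a source loss $L_S$ (real vs. fake) and a class loss $L_C$ (label prediction). The discriminator is trained to maximize $L_S + L_C$, while the generator is trained to maximize $L_C - L_S$:

$$L_S = \mathbb{E}[\log P(S=\mathrm{real} \mid X_{\mathrm{real}})] + \mathbb{E}[\log P(S=\mathrm{fake} \mid X_{\mathrm{fake}})]$$

$$L_C = \mathbb{E}[\log P(C=c \mid X_{\mathrm{real}})] + \mathbb{E}[\log P(C=c \mid X_{\mathrm{fake}})]$$

In the code below, these two terms map onto the 'binary_crossentropy' and 'sparse_categorical_crossentropy' losses passed to compile().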

2. Generator implementation

The code is as follows:

    def built_generator(self):
        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation='relu', input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(BatchNormalization(momentum=0.8))

        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        # model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Conv2D(self.channels, kernel_size=3, padding='same', activation='tanh'))

        model.summary()

        # -----------------
        # Noise and label inputs
        # -----------------
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')

        # Embed the class label into a latent_dim-sized vector and fuse it with
        # the noise by element-wise multiplication (the conditioning step)
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
        model_input = multiply([noise, label_embedding])

        img = model(model_input)

        return Model([noise, label], img)

About the parameter choices in the generator: the first fully connected layer outputs 7x7x128. Since the handwritten digit images are 28x28, the initial spatial size is set to 7x7; the two subsequent upsampling steps grow it to 14x14 and then 28x28, restoring the original image size. The final tanh activation matches the [-1, 1] range the training images are normalized to later in train().

Note: if you want to train on your own image data, remember to work out the relationship between image size and the number of upsampling steps, since each upsampling doubles the feature map's height and width. A quick sanity check is sketched below.
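
Here is that sanity check as a standalone sketch (the helper name initial_size is illustrative, not part of the post's code):

def initial_size(target_size, n_upsamples):
    # Each UpSampling2D layer doubles height and width, so the starting
    # feature-map size is the target size divided by 2**n_upsamples.
    size, remainder = divmod(target_size, 2 ** n_upsamples)
    assert remainder == 0, 'target size must be divisible by 2**n_upsamples'
    return size

print(initial_size(28, 2))  # 7 -> the 7x7 used above for MNIST
print(initial_size(64, 3))  # 8 -> e.g. 64x64 images with three upsampling steps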

3. Discriminator implementation

    def built_discriminator(self):
        model = Sequential()

        model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Conv2D(32, kernel_size=3, strides=2, padding='same'))
        model.add(ZeroPadding2D(padding=((0, 1), (1, 0))))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Conv2D(128, kernel_size=3, strides=1, padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Flatten())
        model.summary()
        img = Input(shape=self.img_shape)

        features = model(img)

        validity = Dense(1, activation='sigmoid')(features)

        label = Dense(self.num_classes, activation='softmax')(features)

        return Model(img, [validity, label])

The discriminator is not much different from an ordinary convolutional network. It takes an image (generated or real) as input and extracts features through convolutions; the difference is that one output judges real vs. fake while the other predicts the class label.
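
To see the two heads in action, here is a minimal sketch; it assumes the ACGAN class from the full code below, and the variable names are illustrative:

import numpy as np

acgan = ACGAN()  # builds and compiles both networks as in the full code

# a dummy batch of four 28x28x1 images in [-1, 1]
batch = np.random.uniform(-1, 1, (4, 28, 28, 1)).astype('float32')

validity, label_probs = acgan.discriminator.predict(batch)
print(validity.shape)     # (4, 1)  - one real/fake score per image
print(label_probs.shape)  # (4, 10) - class probabilities per image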

4. Complete code

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
class ACGAN():
    def __init__(self, img_rows=28, img_cols=28, n_channels=1, num_classes=10):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channels = n_channels
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = num_classes
        self.latent_dim = 100
        optimizer = Adam(0.0002, 0.5)
        losses = ['binary_crossentropy', 'sparse_categorical_crossentropy']

        self.discriminator = self.built_discriminator()
        self.discriminator.compile(loss=losses, optimizer=optimizer, metrics=['acc'])

        self.generator = self.built_generator()
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,))
        img = self.generator([noise, label])

        # Freeze the discriminator inside the combined model so that training
        # `combined` only updates the generator's weights
        self.discriminator.trainable = False

        valid, target_label = self.discriminator(img)

        self.combined = Model([noise, label], [valid, target_label])
        self.combined.compile(loss=losses, optimizer=optimizer)

    def built_generator(self):
        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation='relu', input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(BatchNormalization(momentum=0.8))

        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        # model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Conv2D(self.channels, kernel_size=3, padding='same', activation='tanh'))

        model.summary()

        # -----------------
        # Noise and label inputs
        # -----------------
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')

        # Embed the class label into a latent_dim-sized vector and fuse it with
        # the noise by element-wise multiplication (the conditioning step)
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
        model_input = multiply([noise, label_embedding])

        img = model(model_input)

        return Model([noise, label], img)

    def built_discriminator(self):
        model = Sequential()

        model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Conv2D(32, kernel_size=3, strides=2, padding='same'))
        model.add(ZeroPadding2D(padding=((0, 1), (1, 0))))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Conv2D(128, kernel_size=3, strides=1, padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Flatten())
        model.summary()
        img = Input(shape=self.img_shape)

        features = model(img)

        validity = Dense(1, activation='sigmoid')(features)

        label = Dense(self.num_classes, activation='softmax')(features)

        return Model(img, [validity, label])

    def train(self, epochs, batch_size, sample_interval=50):
        (X_train, y_train), (_, _) = mnist.load_data()
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5  # normalize to [-1, 1] to match the generator's tanh output
        # (60000, 28, 28) -> (60000, 28, 28,1)
        X_train = np.expand_dims(X_train, axis=3)
        # (60000,) -> (60000,1)
        y_train = y_train.reshape(-1, 1)

        valid = np.ones((batch_size, 1))   # target for real images
        fake = np.zeros((batch_size, 1))   # target for generated images

        for epoch in range(epochs):

            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            sampled_labels = np.random.randint(0, 10, (batch_size, 1))

            gen_imgs = self.generator.predict([noise, sampled_labels])

            img_labels = y_train[idx]

            d_loss_real = self.discriminator.train_on_batch(imgs, [valid, img_labels])
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, [fake, sampled_labels])

            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            g_loss = self.combined.train_on_batch([noise, sampled_labels], [valid, sampled_labels])

            print("%d [D loss: %f, acc.: %.2f%%, op_acc: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[3], 100 * d_loss[4], g_loss[0]))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.save_model()
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 10, 10
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        # each of the 10 rows shows the digits 0-9; reshape to (N, 1) to match
        # the label input shape used during training
        sampled_labels = np.array([num for _ in range(r) for num in range(c)]).reshape(-1, 1)
        gen_imgs = self.generator.predict([noise, sampled_labels])
        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()

    def save_model(self):

        def save(model, model_name):
            model_path = "saved_model/%s.json" % model_name
            weights_path = "saved_model/%s_weights.hdf5" % model_name
            options = {"file_arch": model_path,
                       "file_weight": weights_path}
            json_string = model.to_json()
            open(options['file_arch'], 'w').write(json_string)
            model.save_weights(options['file_weight'])

        save(self.generator, "generator")
        save(self.discriminator, "discriminator")

if __name__ == '__main__':
    import os
    # the training loop writes samples and weights into these directories
    os.makedirs('images', exist_ok=True)
    os.makedirs('saved_model', exist_ok=True)

    acgan = ACGAN()
    acgan.train(epochs=14000, batch_size=1024, sample_interval=200)

The network's inputs and outputs are worth studying against the figure; they are indeed a bit tricky to grasp at first.
Output at initialization:

Output after 1000 epochs:

Output after 2000 epochs:
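Finally, since the goal is to generate your own dataset: once training has finished, the saved generator can mass-produce labeled digits. A hedged sketch (the file names follow save_model() above; n_per_class and the output file acgan_mnist.npz are illustrative choices):

import numpy as np
from keras.models import model_from_json

# rebuild the generator from the files written by save_model()
with open('saved_model/generator.json') as f:
    generator = model_from_json(f.read())
generator.load_weights('saved_model/generator_weights.hdf5')

latent_dim = 100
n_per_class = 1000  # how many samples to generate per digit

labels = np.repeat(np.arange(10), n_per_class).reshape(-1, 1)
noise = np.random.normal(0, 1, (len(labels), latent_dim))

imgs = generator.predict([noise, labels])      # tanh output, values in [-1, 1]
imgs = (imgs * 127.5 + 127.5).astype('uint8')  # rescale back to [0, 255]

np.savez('acgan_mnist.npz', images=imgs, labels=labels)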
