GAN生成对抗网络----手写数据实现
Posted 醉公子~
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GAN生成对抗网络----手写数据实现相关的知识,希望对你有一定的参考价值。
目录
GAN------ 以假乱真
GAN 的基本理念其实非常简单,其核心由两个目标互相冲突的神经网络组成,这两个网络会以越来越复杂的方法来“蒙骗”对方。这种情况可以理解为博弈论中的极大极小博弈树。
在这个过程中,我们想象有两类人:警察和罪犯。我们看看他们的之间互相冲突的目标:
- 罪犯的目标:他的主要目标就是想出伪造货币的复杂方法,从而让警察无法区分假币和真币。
- 警察的目标:他的主要目标就是想出辨别货币的复杂方法,这样就能够区分假币和真币。
随着这个过程不断继续,警察会想出越来越复杂的技术来鉴别假币,罪犯也会想出越来越复杂的技术来伪造货币。这就是 GAN 中“对抗过程”的基本理念。
GAN 充分利用“对抗过程”训练两个神经网络,这两个网络会互相博弈直至达到一种理想的平衡状态,我们这个例子中的警察和罪犯就相当于这两个神经网络。
其中一个神经网络叫做生成器网络 G(Z),它会使用输入随机噪声数据,生成和已有数据集非常接近的数据;
另一个神经网络叫鉴别器网络 D(X),它会以生成的数据作为输入,尝试鉴别出哪些是生成的数据,哪些是真实数据。鉴别器的核心是实现二元分类,输出的结果是输入数据来自真实数据集(和合成数据或虚假数据相对)的概率。
我们在前面所说的 GAN 最终能达到一种理想的平衡状态,是指生成器应该能模拟真实的数据,鉴别器输出的概率应该为 0.5, 即生成的数据和真实数据一致。也就是说,它不确定来自生成器的新数据是真实还是虚假,二者的概率相等。
训练流程
环境
- tensorflow 2.4.1
- numpy
- matplotlib
数据集
mnist 手写数字
完整代码
'''
tensorflow 2.4.1
numpy
matplotlib
'''
# 设置GPU内存按需分配
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
import numpy as np
import time
import cv2 as cv
from tensorflow.keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Activation,Flatten,Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from tensorflow.keras.layers import LeakyReLU, Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.optimizers import Adam,RMSprop
import matplotlib.pyplot as plt
class ElapsedTimer(object):
def __init__(self):
self.start_time = time.time()
def elapsed(self,sec):
if sec < 60:
return str(sec) + " sec"
elif sec < (60 * 60):
return str(sec / 60) + " min"
else:
return str(sec / (60 * 60)) + " hr"
def elapsed_time(self):
print("Elapsed: %s " % self.elapsed(time.time() - self.start_time) )
class DCGAN(object):
def __init__(self, img_rows=28, img_cols=28, channel=1):
self.img_rows = img_rows
self.img_cols = img_cols
self.channel = channel
self.D = None # discriminator
self.G = None # generator
self.AM = None # adversarial model
self.DM = None # discriminator model
# (W−F+2P)/S+1
# 判别模型
# 14 * 14 * 1
# 返回一个置信度
def discriminator(self):
if self.D:
return self.D
self.D = Sequential()
depth = 64
dropout = 0.4
# In: 28 x 28 x 1, depth = 1
# Out: 14 x 14 x 1, depth=64
input_shape = (self.img_rows, self.img_cols, self.channel) # 14*14*1 的img
"""
padding = “SAME”输入和输出大小关系:
输出大小等于输入大小除以步长向上取整
padding = “VALID”输入和输出大小关系:
输出大小等于输入大小减去滤波器大小加上1,最后再除以步长
"""
"""
64个5*5大小的内核,步长为2,🔠input:(14,14,1),padding=‘same’保证intput和output一样
"""
self.D.add(Conv2D(64, 5, strides=2, input_shape=input_shape,padding='same'))# 14*14*64
self.D.add(LeakyReLU(alpha=0.2))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(128, 5, strides=2, padding='same')) # 7*7*128
self.D.add(LeakyReLU(alpha=0.2))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(256, 5, strides=2, padding='same')) # 4*4*256 向上取整
self.D.add(LeakyReLU(alpha=0.2))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(512, 5, strides=1, padding='same')) # 4*4*512
self.D.add(LeakyReLU(alpha=0.2))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(256, 5, strides=1, padding='same')) # 4*4*256
self.D.add(LeakyReLU(alpha=0.2))
self.D.add(Dropout(dropout))
# Out: 1-dim probability
self.D.add(Flatten())#扁平 4096=4*4*256
self.D.add(Dense(1)) # 输出 1个
self.D.add(Activation('sigmoid')) # 二分类
self.D.summary()
return self.D
# 生成模型
# 全连接 7*7*256
# 返回一张图 28*28*1
def generator(self):
if self.G:
return self.G
self.G = Sequential()
dropout = 0.4
depth = 64+64+64+64
dim = 7
# In: 100
# Out: dim x dim x depth
self.G.add(Dense(dim*dim*depth, input_dim=100))#全连接 7*7*256 的大小
"""
参数作用于mean和variance的计算上, 这里保留了历史batch里的mean和variance值,即 moving_mean和moving_variance,
借鉴优化算法里的momentum算法将历史batch里的mean和variance的作用延续到当前batch. 一般momentum的值为0.9 , 0.99等.
多个batch后, 即多个0.9连乘后,最早的batch的影响会变弱.
"""
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(Reshape((dim, dim, depth))) # 7*7*256
self.G.add(Dropout(dropout))
# In: dim x dim x depth
# Out: 2*dim x 2*dim x depth/2
self.G.add(UpSampling2D()) # 翻倍 14*14*256
"""
输入图像通过卷积操作提取特征后,输出的尺寸常会变小,而有时我们需要将图像恢复到原来的尺寸以便进行进一步的计算(比如:图像的语义分割),
那么我们需要实现图像由小分辨率到大分辨率的映射的操作,叫做上采样(Upsample)。
"""
self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same')) # 反卷积 14*14*128
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(UpSampling2D())# 28*28*128
self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same')) # 28*28*64
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same')) # 28*28*32
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
# Out: 28 x 28 x 1 grayscale image [0.0,1.0] per pix
self.G.add(Conv2DTranspose(1, 5, padding='same')) # 28*28*1 输出一张特征图(就是生成的图像)
self.G.add(Activation('sigmoid'))
self.G.summary()
return self.G
def discriminator_model(self):
if self.DM:
return self.DM
optimizer = RMSprop(lr=0.0002, decay=6e-8)
self.DM = Sequential()
self.DM.add(self.discriminator())
# print("DM")
# self.DM.summary()
self.DM.compile(loss='binary_crossentropy', optimizer=optimizer,\\
metrics=['accuracy'])
return self.DM
def adversarial_model(self):
if self.AM:
return self.AM
optimizer =RMSprop(lr=0.0001, decay=3e-8)
self.AM = Sequential()
self.AM.add(self.generator())
self.AM.add(self.discriminator())
# print('AM')
# self.AM.summary()
self.AM.compile(loss='binary_crossentropy', optimizer=optimizer,\\
metrics=['accuracy'])
return self.AM
class MNIST_DCGAN(object):
def __init__(self):
self.img_rows = 28
self.img_cols = 28
self.channel = 1
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train / 255.0
self.x_train = X_train.reshape(-1, 28, 28, 1).astype(np.float32)
self.DCGAN = DCGAN()
self.discriminator = self.DCGAN.discriminator_model()
self.adversarial = self.DCGAN.adversarial_model()
self.generator = self.DCGAN.generator()
def train(self, train_steps=2000, batch_size=256, save_interval=0):
noise_input = None
if save_interval>0:
noise_input = np.random.uniform(-1.0, 1.0, size=[16, 100])
for i in range(train_steps):
""""
第一轮,由于是没有权重,随机噪声
再后我们对判别器进行训练之后,loss更新,生成器网络权重更新
"""
images_train = self.x_train[np.random.randint(0,self.x_train.shape[0], size=batch_size), :, :, :] # 随机选取128张图像 [128,28,28,1]
noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100]) #128,100 的随机【-1,1】之间的数
images_fake = self.generator.predict(noise) # 生成模型训练,图 [128,28,28,1]
"""
图像保存 每5轮保存一次生成器所生成的image
"""
if i%5==0:
plt.figure(figsize=(24, 24))
for j in range(16):
plt.subplot(4, 4, j + 1)
image = images_fake[j, :, :, :]
image = np.reshape(image, [28,28])
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.tight_layout()
filename = './g/img_'.format(i)
# plt.savefig(filename)
plt.close('all')
""""
在鉴别器的训练过程中,它显示为真实图像,并用于计算鉴别器损耗。
它对来自生成器的真实和伪造图像进行分类,如果对任何图像进行了不正确分类,则鉴别器损失将对鉴别器进行惩罚。
通过反向传播,鉴别器更新其权重
类似地,为生成器提供了噪声输入以生成伪图像。 这些图像被提供给鉴别器,并且发生器损失惩罚了发生器以产生鉴别器网络分类为伪造的样本。
权重通过从鉴别器到生成器的反向传播进行更新
"""
x = np.concatenate((images_train, images_fake)) #256*28*28*1 维度相加 数组拼接(将训练图片与生成的向量拼接), axis=0 按照行拼接。axis=1 按照列拼接,默认0
print('4',x.shape)
y = np.ones([2*batch_size, 1]) # 生成(256,1)的全是1的数组
y[batch_size:, :] = 0 # 256*1 第128-256行的所有列全为0
d_loss = self.discriminator.train_on_batch(x, y)#鉴别
"""
核心
"""
y = np.ones([batch_size, 1]) # 128*1
noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100]) #128*100
a_loss = self.adversarial.train_on_batch(noise, y)
log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
log_mesg = "%s [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
print(log_mesg)
if save_interval>0:
if (i+1)%save_interval==0:
self.plot_images(save2file=True, samples=noise_input.shape[0],\\
noise=noise_input, step=(i+1))
def plot_images(self, save2file=False, fake=True, samples=16, noise=None, step=0):
filename = 'mnist.png'
if fake:
if noise is None:
noise = np.random.uniform(-1.0, 1.0, size=[samples, 100])
else:
filename = "mnist_%d.png" % step
images = self.generator.predict(noise)
else:
i = np.random.randint(0, self.x_train.shape[0], samples)
images = self.x_train[i, :, :, :]
plt.figure(figsize=(10,10))
for i in range(images.shape[0]):
plt.subplot(4, 4, i+1)
image = images[i, :, :, :]
image = np.reshape(image, [self.img_rows, self.img_cols])
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.tight_layout()
# if save2file:
# plt.savefig(filename)
# plt.close('all')
# else:
# plt.show()
if __name__ == '__main__':
mnist_dcgan = MNIST_DCGAN()
timer = ElapsedTimer()
mnist_dcgan.train(train_steps=10000, batch_size=128, save_interval=1000)
timer.elapsed_time()
mnist_dcgan.plot_images(fake=True)
mnist_dcgan.plot_images(fake=False, save2file=True)
结果展示
【参考文献】
https://www.cnblogs.com/dereen/p/gan.html
https://zhuanlan.zhihu.com/p/43047326
https://www.zhihu.com/question/306213462
以上是关于GAN生成对抗网络----手写数据实现的主要内容,如果未能解决你的问题,请参考以下文章