Keras运行GAN实例(2022.2.25)
Posted jing_zhong
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Keras运行GAN实例(2022.2.25)相关的知识,希望对你有一定的参考价值。
Keras运行GAN实例 2022.2.25
- 1、GAN简介
- 2、测试环境(Win11 64位 + GTX 1050Ti +CUDA 10.1 + cudnn 7.6.5 + Python 3.6 + tensorflow-gpu 2.3.1)
- 3、实例运行
- 3.1 [Adversarial Autoencoder](https://arxiv.org/abs/1511.05644)
- 3.2 [Auxiliary Classifier Generative Adversarial Network](https://arxiv.org/abs/1610.09585)
- 3.3 [Bidirectional Generative Adversarial Network](https://arxiv.org/abs/1605.09782)
- 3.4 [Boundary-Seeking Generative Adversarial Networks](https://arxiv.org/abs/1702.08431)
- 3.5 [Context-Conditional Generative Adversarial Networks](https://arxiv.org/abs/1611.06430)
- 3.6 [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784)
- 3.7 [Coupled generative adversarial networks](https://arxiv.org/abs/1606.07536)
- 3.8 [Context Encoders: Feature Learning by Inpainting](https://arxiv.org/abs/1604.07379)
- 3.9 [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)
- 3.10 [Deep Convolutional Generative Adversarial Network](https://arxiv.org/abs/1511.06434)
- 3.11 [Discover Cross-Domain Relations with Generative Adversarial Networks](https://arxiv.org/abs/1703.05192)
- 3.12 [DualGAN: Unsupervised Dual Learning for Image-to-Image Translation](https://arxiv.org/abs/1704.02510)
- 3.13 [Generative Adversarial Network](https://arxiv.org/abs/1406.2661)
- 3.14 [InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets](https://arxiv.org/abs/1606.03657)
- 3.15 [Least Squares Generative Adversarial Networks](https://arxiv.org/abs/1611.04076)
- 3.16 [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004)
- 3.17 [Unsupervised Pixel-Level Domain Adaptation with Generative Adversarial Networks](https://arxiv.org/abs/1612.05424)
- 3.18 [Wasserstein GAN](https://arxiv.org/abs/1701.07875)
- 4、应用总结
1、GAN简介
GAN
(Generative Adversarial Nets)是由Ian J. Goodfellow
等人在2014NIPS
会议上提出的一种网络,他们提出了一种利用对抗处理来估计生成式模型的新框架,该框架可同时训练两个模型:一个是生成模型G(Generator生成器),可等同于一个函数G(z)
;另一个是判别模型D(Descriminator判别器),也可等同于一个函数D(x)
。
生成器G
用于捕捉数据的分布规律从而生成新的假样本数据;判别器D
用来估计一个样本来自真实训练数据而非生成器G
。因此生成器Generator
的训练过程就是要最大化判别器Descriminator
的犯错概率,相当于生成器Generator要制造假样本,使得假样本尽可能地欺骗判别器Descriminator
,让判别器Descriminator
无法区分假样本到底是来自真实训练数据还是生成器,因此这个框架可认为是一个两个玩家间G
最小D
最大的游戏:G玩家要努力生成与真实样本尽可能接近的假样本,D玩家要准确判断一个样本是来自于真实数据还是G玩家,两个玩家在进行对抗,整个网络在进行对抗训练,从而最终让G玩家预测或生成尽可能新的样本数据。在任意的函数空间G和D,存在唯一解当且仅当G恢复了训练数据的分布,而此时D只等于
1
2
\\frac12
21。
1.1 理论分析
GAN
的原理就是用生成器G和判别器D进行对抗,本质上目的是去做预测和生成,要预测输入真实样本数据x
的概率分布p(x)
,预测的
x
—
—
输
入
的
真
实
样
本
数
据
x——输入的真实样本数据
x——输入的真实样本数据
z
—
—
噪
音
数
据
z——噪音数据
z——噪音数据
G
(
z
)
—
—
生
成
器
G(z)——生成器
G(z)——生成器
D
(
x
)
—
—
判
别
器
D(x)——判别器
D(x)——判别器
x
′
=
G
(
z
)
—
—
生
成
器
G
根
据
噪
音
z
生
成
的
假
样
本
x^' =G(z) ——生成器G根据噪音z 生成的假样本
x′=G(z)——生成器G根据噪音z生成的假样本
p
d
a
t
a
—
—
真
实
样
本
数
据
的
概
率
分
布
p_data——真实样本数据的概率分布
pdata——真实样本数据的概率分布
p
z
—
—
噪
音
数
据
的
先
验
概
率
分
布
,
由
生
成
器
G
来
隐
含
定
义
p_z——噪音数据的先验概率分布,由生成器G来隐含定义
pz——噪音数据的先验概率分布,由生成器G来隐含定义
p
g
—
—
生
成
器
G
预
估
真
实
样
本
数
据
的
概
率
分
布
p_g——生成器G预估真实样本数据的概率分布
pg——生成器G预估真实样本数据的概率分布
网
络
训
练
的
全
局
优
化
目
标
为
:
p
g
=
p
d
a
t
a
网络训练的全局优化目标为:p_g = p_data
网络训练的全局优化目标为:pg=pdata
1.2 优缺点
GAN
的缺点主要包括:无法获得生成器预测样本数据概率分布
p
g
(
x
)
p_g(x)
pg(x)的明确显式表达;在训练期间,判别器D必须与生成器G保持较好的同步和更新。
GAN的
优势有:训练过程无需干预,大量函数都可应用到此模型,可计算性强,能够表示非常尖锐甚至退化的分布。进一步学习GAN
可参考:
- Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[J]. Advances in neural information processing systems, 2014, 27.
- https://speech.ee.ntu.edu.tw/~tlkagk/slide/Tutorial_HYLee_GAN.pdf
- https://tensorflow.google.cn/tutorials/generative/style_transfer?hl=zh_cn
- https://keras.io/examples/generative/
- https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
2、测试环境(Win11 64位 + GTX 1050Ti +CUDA 10.1 + cudnn 7.6.5 + Python 3.6 + tensorflow-gpu 2.3.1)
本文测试环境是在Win11 64位
操作系统上进行的,显卡为GTX 1050Ti
,安装了CUDA 10.1
和cudnn7.6.5
,最后利用Anaconda
创建Python 3.6
的虚拟环境,利用pip
安装了所需的依赖包tensorflow-gpu 2.3.1
。
python 3.6
所安装的依赖包如下(pip install 包名==版本号
):
absl-py==1.0.0
astunparse==1.6.3
cachetools==4.2.4
certifi==2021.10.8
charset-normalizer==2.0.12
cycler==0.11.0
dataclasses==0.8
gast==0.3.3
google-auth==2.6.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.43.0
h5py==2.10.0
idna==3.3
importlib-metadata==4.8.3
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
Markdown==3.3.6
matplotlib==3.3.4
numpy==1.18.5
oauthlib==3.2.0
opt-einsum==3.3.0
pandas==1.1.5
Pillow==8.4.0
protobuf==3.19.4
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyparsing==3.0.7
python-dateutil==2.8.2
pytz==2021.3
requests==2.27.1
requests-oauthlib==1.3.1
rsa==4.8
scipy==1.2.1
six==1.16.0
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow-addons==0.14.0
tensorflow-gpu==2.3.1
tensorflow-gpu-estimator==2.3.0
termcolor==1.1.0
typeguard==2.13.3
typing_extensions==4.1.1
urllib3==1.26.8
Werkzeug==2.0.3
wincertstore==0.2
wordcloud==1.8.1
wrapt==1.12.1
zipp==3.4.0
3、实例运行
所用实例均来自于eriklindernoren/Keras-GAN,为了进一步学习各种GAN的效果,尝试用示例代码运行,源码中训练迭代次数的值通常很大(3000~30000
),由于电脑性能较差特意减少了迭代次数(1000~2000
)来学习,因而运行结果可能不够好,无法证明模型效果不佳。
3.1 Adversarial Autoencoder
代码及运行结果
MergeLayer.py
from tensorflow.keras.layers import Layer
import tensorflow.keras.backend as K
class MergeLayer(Layer):
def __init__(self, **kwargs):
super(MergeLayer, self).__init__(**kwargs)
def compute_output_shape(self, input_shape):
return (input_shape[0][0], input_shape[0][1])
def call(self, x, mask=None):
final_output = x[0] + K.random_normal(K.shape(x[0])) * K.exp(x[1] / 2)
return final_output
aae.py
from __future__ import print_function, division
from MergeLayer import MergeLayer
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise
from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from tensorflow.keras.layers import MaxPooling2D, Concatenate
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import losses
from tensorflow.keras.utils import to_categorical
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
class AdversarialAutoencoder():
def __init__(self):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 10
optimizer = Adam(0.0002, 0.5)
# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
# Build the encoder / decoder
self.encoder = self.build_encoder()
self.decoder = self.build_decoder()
img = Input(shape=self.img_shape)
# The generator takes the image, encodes it and reconstructs it
# from the encoding
encoded_repr = self.encoder(img)
reconstructed_img = self.decoder(encoded_repr)
# For the adversarial_autoencoder model we will only train the generator
self.discriminator.trainable = False
# The discriminator determines validity of the encoding
validity = self.discriminator(encoded_repr)
# The adversarial_autoencoder model (stacked generator and discriminator)
self.adversarial_autoencoder = Model(img, [reconstructed_img, validity])
self.adversarial_autoencoder.compile(loss=['mse', 'binary_crossentropy'],
loss_weights=[0.999, 0.001],
optimizer=optimizer)
def build_encoder(self):
# Encoder
img = Input(shape=self.img_shape)
h = Flatten()(img)
h = Dense(512)(h)
h = LeakyReLU(alpha=0.2)(h)
h = Dense(512)(h)
h = LeakyReLU(alpha=0.2)(h)
mu = Dense(self.latent_dim)(h)
log_var = Dense(self.latent_dim)(h)
latent_repr = MergeLayer()([mu, log_var])
return Model(img, latent_repr)
def build_decoder(self):
model = Sequential()
model.add(Dense(512, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
model.summary()
z = Input(shape=(self.latent_dim,))
img = model(z)
return Model(z, img)
def build_discriminator(self):
model = Sequential()
model.add(Dense(512, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation="sigmoid"))
model.summary()
encoded_repr = Input(shape=(self.latent_dim, ))
validity = model(encoded_repr)
return Model(encoded_repr, validity)
def train(self, epochs, batch_size=128, sample_interval=50):
# Load the dataset
(X_train, _), (_, _) = mnist.load_data()
# Rescale -1 to 1
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for epoch 以上是关于Keras运行GAN实例(2022.2.25)的主要内容,如果未能解决你的问题,请参考以下文章
这些资源你肯定需要!超全的GAN PyTorch+Keras实现集合
python model.trainable = False如何在keras中工作(GAN模型)