换脸火了,我用 Python 快速入门生成模型
Posted CSDN
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了换脸火了,我用 Python 快速入门生成模型相关的知识,希望对你有一定的参考价值。
引言:
近几年来,GAN(生成对抗网络)的应用十分火热,不论是抖音上大火的“蚂蚁牙黑”、B站上的“复原老旧照片”还是换脸等功能,都是基于 GAN 模型实现的。但是 GAN 算法对于大多数人而言上手较难,故今天我们将使用最少的代码,简单入门“生成对抗网络”,实现用 GAN 生成手写数字。
生成的图片效果如下:
os 模块用于对本地文件进行读写、删除、查找等操作
numpy 模块用于矩阵和数据的运算处理,也包括与深度学习框架之间的数据交互等
Keras 模块是一个由 Python 编写的开源人工神经网络库,可以作为 Tensorflow、Microsoft-CNTK 和 Theano 的高阶应用程序接口,进行深度学习模型的设计、调试、评估、应用和可视化。在这里我们用它来搭建网络层并直接读取数据集,简单方便
Matplotlib 模块用于绘制训练效果等数据图,实现可视化
def __init__(self, width=28, height=28, channels=1):
    """Build and compile the generator, the discriminator and the stacked GAN.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        channels: Number of image channels (1 for greyscale MNIST).
    """
    self.width = width
    self.height = height
    self.channels = channels
    # Keras/NumPy image tensors are laid out (rows, cols, channels), i.e.
    # (height, width, channels). The old (width, height, ...) order only
    # worked because MNIST is square.
    self.shape = (self.height, self.width, self.channels)

    def make_optimizer():
        return Adam(lr=0.0002, beta_1=0.5, decay=8e-8)

    # Give every compiled model its own optimizer so they do not share one
    # iteration counter (Adam uses it for bias correction and lr decay).
    self.optimizer = make_optimizer()
    self.G = self.__generator()
    self.G.compile(loss='binary_crossentropy', optimizer=make_optimizer())
    self.D = self.__discriminator()
    self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
    # Built after D is compiled: freezing D inside the stacked model must not
    # affect D's own (already compiled) training.
    self.stacked_generator_discriminator = self.__stacked_generator_discriminator()
    self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=make_optimizer())
1.4 生成器模型的搭建
def __generator(self):
    """Build the generator: a 100-d noise vector -> one image.

    Three Dense + LeakyReLU + BatchNormalization stages (256, 512, 1024 units)
    feed a tanh output layer, so generated pixel values lie in [-1, 1].

    Returns:
        An uncompiled Keras ``Sequential`` model whose output has shape
        ``(height, width, channels)``.
    """
    model = Sequential()
    model.add(Dense(256, input_shape=(100,)))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(self.width * self.height * self.channels, activation='tanh'))
    # (height, width, channels) matches the row-major layout of real images.
    model.add(Reshape((self.height, self.width, self.channels)))
    return model
1.5 判别器模型的搭建
def __discriminator(self):
    """Build the discriminator: image -> probability that the image is real.

    The image is flattened and passed through two Dense + LeakyReLU layers
    (``pixels`` and ``pixels // 2`` units) before one sigmoid output unit.
    The model summary is printed.

    Returns:
        An uncompiled Keras ``Sequential`` model accepting ``self.shape`` inputs.
    """
    pixels = self.width * self.height * self.channels
    model = Sequential()
    model.add(Flatten(input_shape=self.shape))
    # input_shape only matters on the first layer of a Sequential model, so
    # it is not repeated here.
    model.add(Dense(pixels))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(pixels // 2))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()
    return model
1.6 对抗式模型的搭建
def __stacked_generator_discriminator(self):
    """Chain the generator and the discriminator into one model.

    ``self.D.trainable`` is set to False before the model is returned. The
    intent is that training this stacked model updates only the generator.
    In ``__init__``, ``self.D`` is compiled before this is called.

    Returns:
        An uncompiled Keras ``Sequential`` model: noise -> G -> D -> score.
    """
    self.D.trainable = False
    stacked = Sequential([self.G, self.D])
    return stacked
def train(self, X_train, epochs=20000, batch=32, save_interval=100):
    """Alternately train the discriminator and the generator.

    Each iteration ("epoch" here means one mini-batch step) proceeds in two stages:
    1. The discriminator is trained on ``batch // 2`` consecutive real images
       (label 1) plus ``batch // 2`` generated images (label 0).
    2. The generator is trained through the stacked model on ``batch`` noise
       vectors labelled as real (1).

    Args:
        X_train: Real training images. Each sample must hold
            ``height * width * channels`` values; it is reshaped to
            ``(height, width, channels)``.
        epochs: Number of training iterations.
        batch: Generator batch size; the discriminator sees half real and
            half generated images.
        save_interval: Save a sample grid every this many iterations.

    Raises:
        ValueError: If ``batch < 2`` or ``X_train`` has fewer than
            ``batch // 2`` samples.
    """
    half = batch // 2
    if half < 1:
        raise ValueError("batch must be at least 2")
    if len(X_train) < half:
        raise ValueError("X_train needs at least batch // 2 samples")
    for cnt in range(epochs):
        # --- train discriminator: real images -> 1, generated images -> 0 ---
        # randint's upper bound is exclusive; +1 makes the last full window
        # reachable (and avoids randint(0, 0) when len(X_train) == half).
        start = np.random.randint(0, len(X_train) - half + 1)
        legit_images = X_train[start:start + half].reshape(
            half, self.height, self.width, self.channels)
        gen_noise = np.random.normal(0, 1, (half, 100))
        synthetic_images = self.G.predict(gen_noise)
        x_combined_batch = np.concatenate((legit_images, synthetic_images))
        y_combined_batch = np.concatenate((np.ones((half, 1)), np.zeros((half, 1))))
        d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)

        # --- train generator: label fakes as real so G learns to fool D ---
        noise = np.random.normal(0, 1, (batch, 100))
        y_mislabeled = np.ones((batch, 1))
        g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_mislabeled)

        print('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss))
        if cnt % save_interval == 0:
            self.plot_images(save2file=True, step=cnt)
2.2 可视化
def plot_images(self, save2file=False, samples=16, step=0):
    """Plot a grid of images sampled from the generator.

    Args:
        save2file: If True, save the figure to ``./images/mnist_<step>.png``
            (creating ``./images`` if needed) and close it. Otherwise display
            it with ``plt.show()``.
        samples: Number of images to generate. The grid is sized to fit them;
            16 samples give a 4x4 grid.
        step: Training step embedded in the output filename.
    """
    filename = "./images/mnist_%d.png" % step
    noise = np.random.normal(0, 1, (samples, 100))
    images = self.G.predict(noise)

    # Near-square grid big enough for every sample. It was hard-coded to 4x4,
    # which crashed for samples > 16.
    cols = max(1, int(np.ceil(np.sqrt(samples))))
    rows = max(1, int(np.ceil(samples / float(cols))))
    plt.figure(figsize=(10, 10))
    for i, image in enumerate(images):
        plt.subplot(rows, cols, i + 1)
        if self.channels == 1:
            # Drop the channel axis; imshow auto-scales single-channel data.
            plt.imshow(np.reshape(image, [self.height, self.width]), cmap='gray')
        else:
            # Float RGB(A) data must lie in [0, 1]; map the tanh range [-1, 1].
            plt.imshow(np.clip((image + 1.0) / 2.0, 0.0, 1.0))
        plt.axis('off')
    plt.tight_layout()
    if save2file:
        if not os.path.exists("./images"):
            os.makedirs("./images")
        plt.savefig(filename)
        plt.close('all')
    else:
        plt.show()
# -*- coding: utf-8 -*-
import os
import numpy as np
# NOTE(review): Tracer is never used in this file and has been deprecated in
# IPython since 5.1; newer IPython releases may not provide it, which would
# make this import fail. Consider removing it — verify against your IPython.
from IPython.core.debugger import Tracer
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam
import matplotlib.pyplot as plt
plt.switch_backend( 'agg') # allows code to run without a system DISPLAY
class GAN(object):
    """ Generative Adversarial Network class """

    def __init__(self, width=28, height=28, channels=1):
        """Build and compile the generator, the discriminator and the stacked GAN.

        Args:
            width: Image width in pixels.
            height: Image height in pixels.
            channels: Number of image channels (1 for greyscale MNIST).
        """
        self.width = width
        self.height = height
        self.channels = channels
        # Keras/NumPy image tensors are laid out (rows, cols, channels), i.e.
        # (height, width, channels). The old (width, height, ...) order only
        # worked because MNIST is square.
        self.shape = (self.height, self.width, self.channels)

        def make_optimizer():
            return Adam(lr=0.0002, beta_1=0.5, decay=8e-8)

        # Give every compiled model its own optimizer so they do not share one
        # iteration counter (Adam uses it for bias correction and lr decay).
        self.optimizer = make_optimizer()
        self.G = self.__generator()
        self.G.compile(loss='binary_crossentropy', optimizer=make_optimizer())
        self.D = self.__discriminator()
        self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
        # Built after D is compiled: freezing D inside the stacked model must
        # not affect D's own (already compiled) training.
        self.stacked_generator_discriminator = self.__stacked_generator_discriminator()
        self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=make_optimizer())

    def __generator(self):
        """Build the generator: a 100-d noise vector -> one image.

        Three Dense + LeakyReLU + BatchNormalization stages (256, 512, 1024
        units) feed a tanh output layer, so pixel values lie in [-1, 1].

        Returns:
            An uncompiled ``Sequential`` model with output shape
            ``(height, width, channels)``.
        """
        model = Sequential()
        model.add(Dense(256, input_shape=(100,)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.width * self.height * self.channels, activation='tanh'))
        # (height, width, channels) matches self.shape and the real images.
        model.add(Reshape((self.height, self.width, self.channels)))
        return model

    def __discriminator(self):
        """Build the discriminator: image -> probability that it is real.

        The image is flattened and passed through two Dense + LeakyReLU layers
        (``pixels`` and ``pixels // 2`` units) before a single sigmoid unit.
        The model summary is printed.

        Returns:
            An uncompiled ``Sequential`` model accepting ``self.shape`` inputs.
        """
        pixels = self.width * self.height * self.channels
        model = Sequential()
        model.add(Flatten(input_shape=self.shape))
        # input_shape only matters on the first layer of a Sequential model,
        # so it is not repeated here.
        model.add(Dense(pixels))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(pixels // 2))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        return model

    def __stacked_generator_discriminator(self):
        """Chain G and D (with D frozen) into the model used to train G.

        Returns:
            An uncompiled ``Sequential`` model: noise -> G -> D -> score.
        """
        # Frozen so that training the stacked model only updates G. self.D was
        # compiled beforehand in __init__ and keeps training on its own.
        self.D.trainable = False
        model = Sequential()
        model.add(self.G)
        model.add(self.D)
        return model

    def train(self, X_train, epochs=20000, batch=32, save_interval=100):
        """Alternately train the discriminator and the generator.

        Each iteration ("epoch" here means one mini-batch step) proceeds in
        two stages:
        1. D is trained on ``batch // 2`` consecutive real images (label 1)
           plus ``batch // 2`` generated images (label 0).
        2. G is trained through the stacked model on ``batch`` noise vectors
           labelled as real.

        Args:
            X_train: Real training images. Each sample must hold
                ``height * width * channels`` values; it is reshaped to
                ``(height, width, channels)``.
            epochs: Number of training iterations.
            batch: Generator batch size; D sees half real, half generated.
            save_interval: Save a sample grid every this many iterations.

        Raises:
            ValueError: If ``batch < 2`` or ``X_train`` has fewer than
                ``batch // 2`` samples.
        """
        half = batch // 2
        if half < 1:
            raise ValueError("batch must be at least 2")
        if len(X_train) < half:
            raise ValueError("X_train needs at least batch // 2 samples")
        for cnt in range(epochs):
            # --- train discriminator: real -> 1, generated -> 0 ---
            # randint's upper bound is exclusive; +1 makes the last full
            # window reachable (and avoids randint(0, 0) when len == half).
            start = np.random.randint(0, len(X_train) - half + 1)
            legit_images = X_train[start:start + half].reshape(
                half, self.height, self.width, self.channels)
            gen_noise = np.random.normal(0, 1, (half, 100))
            synthetic_images = self.G.predict(gen_noise)
            x_combined_batch = np.concatenate((legit_images, synthetic_images))
            y_combined_batch = np.concatenate((np.ones((half, 1)), np.zeros((half, 1))))
            d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)

            # --- train generator: label fakes as real so G learns to fool D ---
            noise = np.random.normal(0, 1, (batch, 100))
            y_mislabeled = np.ones((batch, 1))
            g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_mislabeled)

            print('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss))
            if cnt % save_interval == 0:
                self.plot_images(save2file=True, step=cnt)

    def plot_images(self, save2file=False, samples=16, step=0):
        """Plot a grid of images sampled from the generator.

        Args:
            save2file: If True, save the figure to
                ``./images/mnist_<step>.png`` (creating ``./images`` if needed)
                and close it. Otherwise display it with ``plt.show()``.
            samples: Number of images to generate. The grid is sized to fit
                them; 16 samples give a 4x4 grid.
            step: Training step embedded in the output filename.
        """
        filename = "./images/mnist_%d.png" % step
        noise = np.random.normal(0, 1, (samples, 100))
        images = self.G.predict(noise)

        # Near-square grid big enough for every sample. It was hard-coded to
        # 4x4, which crashed for samples > 16.
        cols = max(1, int(np.ceil(np.sqrt(samples))))
        rows = max(1, int(np.ceil(samples / float(cols))))
        plt.figure(figsize=(10, 10))
        for i, image in enumerate(images):
            plt.subplot(rows, cols, i + 1)
            if self.channels == 1:
                # Drop the channel axis; imshow auto-scales single-channel data.
                plt.imshow(np.reshape(image, [self.height, self.width]), cmap='gray')
            else:
                # Float RGB(A) data must lie in [0, 1]; map the tanh range [-1, 1].
                plt.imshow(np.clip((image + 1.0) / 2.0, 0.0, 1.0))
            plt.axis('off')
        plt.tight_layout()
        if save2file:
            if not os.path.exists("./images"):
                os.makedirs("./images")
            plt.savefig(filename)
            plt.close('all')
        else:
            plt.show()
if __name__ == '__main__':
    # Only the training images are needed; labels and the test split are discarded.
    (X_train, _), (_, _) = mnist.load_data()
    # Rescale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    # Append a channel axis so each sample is (height, width, 1).
    X_train = np.expand_dims(X_train, axis=3)
    gan = GAN()
    gan.train(X_train)
60+专家,13个技术领域,CSDN 《IT 人才成长路线图》重磅来袭!
以上是关于换脸火了,我用 Python 快速入门生成模型的主要内容,如果未能解决你的问题,请参考以下文章
AI换脸在电竞圈火了!大司马PDD大秀肌肉辣舞,网友:上头,流鼻血了
Frechlet Inception Distance(FID)快速入门使用代码