Context_Encoder in Practice on MNIST
Posted by nanhaijindiao
Context_Encoder is a GAN-based image inpainting framework; links to brief theory write-ups are given below. The paper corrupts images in three ways: carve a single square region out of the image (matrix) and set its values to 0; carve out an arbitrary number n of squares and zero them; or set an arbitrary scatter of individual values in the image (matrix) to 0. The third mode is the one people care about most, and also the closest to real-world damage. The Keras-GAN reference implementation (linked below) provides inpainting code for the 3-channel CIFAR dataset (3-D data), and it handles the first corruption mode. Starting from that code, I adapted it to inpainting on the 2-D MNIST dataset. A minimal sketch of the three corruption modes follows.
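To make the three corruption modes concrete, here is a minimal NumPy sketch (my own illustration, not code from the paper) applying each mode to a single-channel 28x28 image:

import numpy as np

def mask_center_block(img, size=8):
    # Mode 1: zero out a single square block (here placed at the center)
    out = img.copy()
    y = (img.shape[0] - size) // 2
    x = (img.shape[1] - size) // 2
    out[y:y + size, x:x + size] = 0
    return out

def mask_random_blocks(img, n=3, size=4):
    # Mode 2: zero out n randomly placed square blocks
    out = img.copy()
    for _ in range(n):
        y = np.random.randint(0, img.shape[0] - size)
        x = np.random.randint(0, img.shape[1] - size)
        out[y:y + size, x:x + size] = 0
    return out

def mask_random_pixels(img, frac=0.25):
    # Mode 3: zero out a random fraction of individual pixels
    out = img.copy()
    out[np.random.rand(*img.shape) < frac] = 0
    return out

The MNIST experiment below uses the first mode, with a single 8x8 square placed at a random position in each image.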
Theory references:
1.https://blog.csdn.net/qq_33594380/article/details/85317922
2.https://www.cnblogs.com/wmr95/p/10636804.html
The Keras-GAN CIFAR tutorial:
https://github.com/eriklindernoren/Keras-GAN/blob/master/context_encoder/context_encoder.py
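CIFAR-10 images are 32x32 RGB while MNIST digits are 28x28 grayscale, so besides shrinking the network's input shape from (32, 32, 3) to (28, 28, 1), the main change is in data loading: MNIST arrays come without a channel axis and need an explicit reshape. A standalone sketch of just that loading step, matching what the full script does:

from keras.datasets import mnist

(X_train, _), (_, _) = mnist.load_data()   # shape (60000, 28, 28), no channel axis
X_train = X_train.reshape(-1, 28, 28, 1)   # add the channel axis Conv2D expects
X_train = X_train / 127.5 - 1.             # rescale [0, 255] -> [-1, 1] for the tanh output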
Here is my modified code:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Flatten, Dropout
from keras.layers import BatchNormalization, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np


class ContextEncoder():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.mask_height = 8
        self.mask_width = 8
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.missing_shape = (self.mask_height, self.mask_width, self.channels)

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes the masked image as input and generates
        # the missing patch
        masked_img = Input(shape=self.img_shape)
        gen_missing = self.generator(masked_img)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes the generated patch as input and
        # judges whether it is real or generated
        valid = self.discriminator(gen_missing)

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator; the loss is
        # dominated by pixel-wise MSE with a small adversarial term
        self.combined = Model(masked_img, [gen_missing, valid])
        self.combined.compile(loss=['mse', 'binary_crossentropy'],
                              loss_weights=[0.999, 0.001],
                              optimizer=optimizer)

    def build_generator(self):
        model = Sequential()

        # Encoder: with "same" padding and stride 2, 28 -> 14 -> 7 -> 4 -> 2
        model.add(Conv2D(32, kernel_size=3, strides=2,
                         input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(512, kernel_size=1, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.5))

        # Decoder: 2 -> 4 -> 8, so the output matches the 8x8 missing patch
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation('relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation('relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
        model.add(Activation('tanh'))

        model.summary()

        masked_img = Input(shape=self.img_shape)
        gen_missing = model(masked_img)

        return Model(masked_img, gen_missing)

    def build_discriminator(self):
        model = Sequential()

        model.add(Conv2D(64, kernel_size=3, strides=2,
                         input_shape=self.missing_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(256, kernel_size=3, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.missing_shape)
        validity = model(img)

        return Model(img, validity)

    def mask_randomly(self, imgs):
        # Pick a random top-left corner for each image's mask
        y1 = np.random.randint(0, self.img_rows - self.mask_height, imgs.shape[0])
        y2 = y1 + self.mask_height
        x1 = np.random.randint(0, self.img_cols - self.mask_width, imgs.shape[0])
        x2 = x1 + self.mask_width

        masked_imgs = np.empty_like(imgs)
        missing_parts = np.empty((imgs.shape[0], self.mask_height,
                                  self.mask_width, self.channels))
        for i, img in enumerate(imgs):
            masked_img = img.copy()
            _y1, _y2, _x1, _x2 = y1[i], y2[i], x1[i], x2[i]
            missing_parts[i] = masked_img[_y1:_y2, _x1:_x2, :].copy()
            masked_img[_y1:_y2, _x1:_x2, :] = 0
            masked_imgs[i] = masked_img

        return masked_imgs, missing_parts, (y1, y2, x1, x2)

    def train(self, epochs, batch_size=128, sample_interval=50):
        # Load the dataset
        (X_train, y_train), (_, _) = mnist.load_data()

        # Keep only the digits 3 and 5 (the variable names are left over
        # from the cats-vs-dogs split in the original CIFAR example)
        X_cats = X_train[(y_train == 3).flatten()]
        X_dogs = X_train[(y_train == 5).flatten()]
        X_train = np.vstack((X_cats, X_dogs))
        X_train = X_train.reshape(-1, 28, 28, 1)

        # Rescale to [-1, 1] to match the generator's tanh output
        X_train = X_train / 127.5 - 1.

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            masked_imgs, missing_parts, _ = self.mask_randomly(imgs)

            # Generate a batch of missing patches
            gen_missing = self.generator.predict(masked_imgs)

            # Train the discriminator on real vs. generated patches
            d_loss_real = self.discriminator.train_on_batch(missing_parts, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_missing, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            g_loss = self.combined.train_on_batch(masked_imgs,
                                                  [missing_parts, valid])

            # Plot the progress
            print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" %
                  (epoch, d_loss[0], 100 * d_loss[1], g_loss[0], g_loss[1]))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                idx = np.random.randint(0, X_train.shape[0], 6)
                imgs = X_train[idx]
                self.sample_images(epoch, imgs)

    def plot_image(self, image):
        fig = plt.gcf()
        fig.set_size_inches(2, 2)
        plt.imshow(image, cmap='binary')
        plt.show()

    def sample_images(self, epoch, imgs):
        masked_imgs, missing_parts, (y1, y2, x1, x2) = self.mask_randomly(imgs)
        gen_missing = self.generator.predict(masked_imgs)

        # Rescale from [-1, 1] back to [0, 1]
        imgs = 0.5 * imgs + 0.5                 # original images
        masked_imgs = 0.5 * masked_imgs + 0.5   # images with the hole
        gen_missing = 0.5 * gen_missing + 0.5   # predicted missing patches

        # Paste the predicted patch back into one sample image
        filled_in = imgs[1].copy()
        filled_in[y1[1]:y2[1], x1[1]:x2[1], :] = gen_missing[1]

        imgs = imgs.reshape(-1, 28, 28)
        masked_imgs = masked_imgs.reshape(-1, 28, 28)
        filled_in = filled_in.reshape(-1, 28, 28)
        self.plot_image(imgs[1])
        self.plot_image(masked_imgs[1])
        self.plot_image(filled_in[0])
        plt.close()

    def save_model(self):

        def save(model, model_name):
            # Assumes the saved_model/ directory already exists
            model_path = "saved_model/%s.json" % model_name
            weights_path = "saved_model/%s_weights.hdf5" % model_name
            options = {"file_arch": model_path,
                       "file_weight": weights_path}
            json_string = model.to_json()
            open(options['file_arch'], 'w').write(json_string)
            model.save_weights(options['file_weight'])

        save(self.generator, "generator")
        save(self.discriminator, "discriminator")


if __name__ == '__main__':
    context_encoder = ContextEncoder()
    context_encoder.train(epochs=2000, batch_size=64, sample_interval=1999)
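A usage note: save_model serializes each network to saved_model/ (the directory must already exist). A trained generator can then be reloaded and applied to a single masked image along the lines below; this is a sketch of my own, and inpaint is a hypothetical helper rather than part of the script above:

from keras.models import model_from_json
import numpy as np

# Rebuild the generator from the saved architecture and weights
with open("saved_model/generator.json") as f:
    generator = model_from_json(f.read())
generator.load_weights("saved_model/generator_weights.hdf5")

# masked_img: a (28, 28, 1) image in [-1, 1] with an 8x8 hole at (y1, x1)
def inpaint(masked_img, y1, x1):
    patch = generator.predict(masked_img[np.newaxis])[0]  # -> (8, 8, 1)
    filled = masked_img.copy()
    filled[y1:y1 + 8, x1:x1 + 8, :] = patch
    return filled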