『TensorFlow』 A Class-Based Paradigm for Neural Networks, Using a GAN as the Example
1. Import packages:
import os
import time
import math
from glob import glob

from PIL import Image
import tensorflow as tf
import numpy as np

import ops    # local module wrapping the layer functions (conv, deconv, linear, lrelu, batch_norm)
import utils  # other local helper functions (image loading/saving, etc.)
2. A small ad-hoc helper function:
def conv_out_size_same(size, stride):
    # ceiling division: the smallest integer >= size / stride
    return int(math.ceil(float(size) / float(stride)))
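A quick usage sketch: starting from a 64×64 output and halving repeatedly with this helper gives the intermediate feature-map sizes (s_h2 … s_h16) that the generator below walks back through:

# sizes for output_height = 64, stride 2 at every layer
s = 64
for _ in range(4):
    s = conv_out_size_same(s, 2)
    print(s)   # 32, 16, 8, 4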
3. Declaring and initializing the class:
Class attributes are not used in this example, but in practice they often are.
Class attributes & __init__: the constructor receives the arguments and turns them into low-level attribute values; data loading (or at least the list of data file names) is usually placed in __init__ as well.
class DCGAN():
    def __init__(self, sess, input_height=108, input_width=108, crop=True,
                 batch_size=64, sample_num=64,
                 output_height=64, output_width=64,
                 z_dim=100, gf_dim=64, df_dim=64, gfc_dim=1024, dfc_dim=1024, c_dim=3,
                 dataset_name='default', input_fname_pattern='*.jpg',
                 checkpoint_dir=None, sample_dir=None):
        """
        Args:
            sess: TensorFlow session
            batch_size: The size of batch. Should be specified before training.
            z_dim: (optional) Dimension of dim for Z. [100]
            gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
            df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
            gfc_dim: (optional) Dimension of gen units for fully connected layer. [1024]
            dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
            c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
        """
        self.sess = sess
        self.batch_size = batch_size
        self.sample_num = sample_num

        # crop controls the input/output sizes:
        #   crop=True : the output size is the network's input size (images are cropped)
        #   crop=False: the input goes straight into the network's input layer
        self.crop = crop
        self.input_height = input_height
        self.input_width = input_width
        self.output_height = output_height
        self.output_width = output_width

        self.z_dim = z_dim
        self.gf_dim = gf_dim
        self.df_dim = df_dim
        self.dfc_dim = dfc_dim
        self.gfc_dim = gfc_dim

        self.g_bn0 = ops.batch_norm(name='g_bn0')
        self.g_bn1 = ops.batch_norm(name='g_bn1')
        self.g_bn2 = ops.batch_norm(name='g_bn2')
        self.g_bn3 = ops.batch_norm(name='g_bn3')
        self.d_bn1 = ops.batch_norm(name='d_bn1')
        self.d_bn2 = ops.batch_norm(name='d_bn2')
        self.d_bn3 = ops.batch_norm(name='d_bn3')

        '''Load the data'''
        self.dataset_name = dataset_name
        self.input_fname_pattern = input_fname_pattern
        self.checkpoint_dir = checkpoint_dir
        # collect the paths of all images in the dataset
        self.data = glob(os.path.join('./data', self.dataset_name, self.input_fname_pattern))

        '''Read one image to determine the number of channels'''
        imreadImg = np.asarray(Image.open(self.data[0]))
        if len(imreadImg.shape) >= 3:
            self.c_dim = imreadImg.shape[-1]
        else:
            self.c_dim = 1
        self.grayscale = (self.c_dim == 1)
4. Building the network:
Because of the way a GAN is structured, this part is split into build_model(self) as the backbone, with discriminator(self, image, reuse=False) and generator(self, z) as modules; together they cover the whole flow from feeding data into the network to computing the loss functions.
    def build_model(self):
        if self.crop:
            image_dims = [self.output_height, self.output_width, self.c_dim]
        else:
            image_dims = [self.input_height, self.input_width, self.c_dim]

        '''Input layer'''
        # concatenate with + (list.extend() returns None and would break the shape)
        self.input_layer = tf.placeholder(
            tf.float32, [self.batch_size] + image_dims, name='input_layer')
        inputs = self.input_layer

        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')
        self.z_sum = tf.summary.histogram('z', self.z)

        '''Main computation nodes'''
        # forward passes
        self.G = self.generator(self.z)
        self.D, self.D_logits = self.discriminator(inputs, reuse=False)
        # note: this rebinds self.sampler from the method to its output tensor
        self.sampler = self.sampler(self.z)
        self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)
        # summaries
        self.G_sum = tf.summary.image('G', self.G)
        self.D_sum = tf.summary.histogram('D', self.D)
        self.D__sum = tf.summary.histogram('D_', self.D_)

        '''Loss functions'''
        # build the losses (labels/logits passed as keyword arguments)
        self.d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_logits, labels=tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_logits_, labels=tf.ones_like(self.D_)))
        self.d_loss = self.d_loss_real + self.d_loss_fake
        # summaries
        self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)

        # split the trainable variables between the two optimizers
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        # saver object
        self.saver = tf.train.Saver()

    def discriminator(self, image, reuse=False):
        with tf.variable_scope('discriminator', reuse=reuse):
            h0 = ops.lrelu(ops.conv2d(image, self.df_dim, name='d_h0_conv'))
            h1 = ops.lrelu(self.d_bn1(ops.conv2d(h0, self.df_dim * 2, name='d_h1_conv')))
            h2 = ops.lrelu(self.d_bn2(ops.conv2d(h1, self.df_dim * 4, name='d_h2_conv')))
            h3 = ops.lrelu(self.d_bn3(ops.conv2d(h2, self.df_dim * 8, name='d_h3_conv')))
            h4 = ops.linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin')
            return tf.nn.sigmoid(h4), h4

    def generator(self, z):
        with tf.variable_scope('generator'):
            # target image size, then the intermediate feature-map sizes
            s_h, s_w = self.output_height, self.output_width
            s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
            s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
            s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
            s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

            # batch_size stays fixed; h and w double at every layer while the channel count halves
            # linear projection layer
            self.z_, self.h0_w, self.h0_b = ops.linear(
                z, self.gf_dim * 8 * s_h16 * s_w16, 'g_h0_lin', with_w=True)
            self.h0 = tf.reshape(self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(self.h0))

            # transposed-convolution (deconv) layers
            self.h1, self.h1_w, self.h1_b = ops.deconv2d(
                h0, [self.batch_size, s_h8, s_w8, self.gf_dim * 4], name='g_h1', with_w=True)
            h1 = tf.nn.relu(self.g_bn1(self.h1))

            h2, self.h2_w, self.h2_b = ops.deconv2d(
                h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2], name='g_h2', with_w=True)
            h2 = tf.nn.relu(self.g_bn2(h2))

            h3, self.h3_w, self.h3_b = ops.deconv2d(
                h2, [self.batch_size, s_h2, s_w2, self.gf_dim * 1], name='g_h3', with_w=True)
            h3 = tf.nn.relu(self.g_bn3(h3))

            h4, self.h4_w, self.h4_b = ops.deconv2d(
                h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)
            return tf.nn.tanh(h4)
5. The prediction part:
This is the part an ordinary network would use to predict labels; for a GAN it is where the synthetic sample images are generated, and it does not take part in training.
    def sampler(self, z):
        # Exactly the same structure as the generator, sharing its variables; the only difference
        # is that batch norm runs with train=False, which switches it to the moving-average statistics.
        with tf.variable_scope("generator") as scope:
            scope.reuse_variables()

            s_h, s_w = self.output_height, self.output_width
            s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
            s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
            s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
            s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

            h0 = tf.reshape(ops.linear(z, self.gf_dim * 8 * s_h16 * s_w16, 'g_h0_lin'),
                            [-1, s_h16, s_w16, self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(h0, train=False))
            h1 = ops.deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim * 4], name='g_h1')
            h1 = tf.nn.relu(self.g_bn1(h1, train=False))
            h2 = ops.deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2], name='g_h2')
            h2 = tf.nn.relu(self.g_bn2(h2, train=False))
            h3 = ops.deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim * 1], name='g_h3')
            h3 = tf.nn.relu(self.g_bn3(h3, train=False))
            h4 = ops.deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4')
            return tf.nn.tanh(h4)  # same tanh output as the generator
6. The training part:
This is by far the most involved part:
- build the optimizers
- load the result of the previous training run
- iterate over the training loop
  - read a batch_size worth of data
  - feed it into the network and train
  - print intermediate quantities for monitoring
  - save the model
    def train(self, config):
        # discriminator optimizer (on the combined d_loss)
        d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        # generator optimizer
        g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)
        tf.global_variables_initializer().run()

        # merge the summaries so each value can be tracked across iterations
        self.g_sum = tf.summary.merge([self.z_sum, self.D__sum,
                                       self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
        self.d_sum = tf.summary.merge([self.z_sum, self.D_sum,
                                       self.d_loss_real_sum, self.d_loss_sum])
        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

        # load sample_num images, used later for the periodic sample grids
        sample_files = self.data[0:self.sample_num]
        sample = [utils.get_image(sample_file,
                                  input_height=self.input_height,
                                  input_width=self.input_width,
                                  resize_height=self.output_height,
                                  resize_width=self.output_width,
                                  crop=self.crop) for sample_file in sample_files]
        sample_inputs = np.array(sample).astype(np.float32)
        sample_z = np.random.uniform(-1, 1, size=(self.sample_num, self.z_dim))

        counter = 1
        start_time = time.time()
        # load an existing model and resume training from it if possible
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for epoch in range(config.epoch):
            self.data = glob(os.path.join(
                "./data", config.dataset, self.input_fname_pattern))
            batch_idxs = min(len(self.data), config.train_size) // config.batch_size

            for idx in range(0, batch_idxs):
                # read one batch of real images x
                batch_files = self.data[idx * config.batch_size:(idx + 1) * config.batch_size]
                batch = [utils.get_image(batch_file,
                                         input_height=self.input_height,
                                         input_width=self.input_width,
                                         resize_height=self.output_height,
                                         resize_width=self.output_width,
                                         crop=self.crop) for batch_file in batch_files]
                batch_images = np.array(batch).astype(np.float32)
                # sample the noise z
                batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \
                    .astype(np.float32)

                # Update D network
                _, summary_str = self.sess.run(
                    [d_optim, self.d_sum],
                    feed_dict={self.input_layer: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # Update G network
                _, summary_str = self.sess.run([g_optim, self.g_sum],
                                               feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)
                # note: what the writer receives here is the serialized summary string
                # returned by sess.run, not a plain scalar value

                # Update G network again
                # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
                _, summary_str = self.sess.run([g_optim, self.g_sum],
                                               feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # evaluate the current loss values
                errD_fake = self.d_loss_fake.eval({self.z: batch_z})
                errD_real = self.d_loss_real.eval({self.input_layer: batch_images})
                errG = self.g_loss.eval({self.z: batch_z})

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, batch_idxs,
                         time.time() - start_time, errD_fake + errD_real, errG))

                # every 100 steps, save a grid of generated samples
                if np.mod(counter, 100) == 1:
                    try:
                        samples, d_loss, g_loss = self.sess.run(
                            [self.sampler, self.d_loss, self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.input_layer: sample_inputs,
                            },
                        )
                        utils.save_images(samples, utils.image_manifold_size(samples.shape[0]),
                                          './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
                    except:
                        print("one pic error!...")

                # every 500 steps, save a checkpoint
                if np.mod(counter, 500) == 2:
                    self.save(config.checkpoint_dir, counter)
A demo of saving & loading the model.
Personally I find it a bit bloated, but it is still well worth borrowing from.
For example, using a decorator to hide a method behind a property feels unnecessary to me, since it is only ever called internally (a minimal sketch of that pattern follows the snippet below).
The standard idiom for checking that a directory exists, though, is a keeper:
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
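For reference, the property pattern mentioned above, in isolation; the class and attribute names here are invented for illustration, not taken from the original code:

class Runner:  # illustrative stand-in, not part of the DCGAN code
    def __init__(self, dataset_name, batch_size):
        self.dataset_name = dataset_name
        self.batch_size = batch_size

    @property
    def model_dir(self):
        # computed on access, so callers write runner.model_dir with no parentheses
        return "{}_{}".format(self.dataset_name, self.batch_size)

runner = Runner("celebA", 64)
print(runner.model_dir)  # celebA_64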
To run different datasets, the author put some effort into how the checkpoint files are named and grouped, which makes the load module relatively complex, so I have added a few extra comments there.
    '''Saving & loading the model'''
    # layout: checkpoint_dir/datasetname_batchsize_outputheight_outputwidth/<model files>
    @property
    def model_dir(self):
        return "{}_{}_{}_{}".format(
            self.dataset_name, self.batch_size,
            self.output_height, self.output_width)

    def save(self, checkpoint_dir, step):
        model_name = "DCGAN.model"
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        # join the checkpoint root with the per-dataset directory
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
        # checkpoint directory -> state object naming the latest model file
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # strip the path from the model file name (arguably redundant: the name
            # stored in the checkpoint file already carries no path)
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            # restore the parameters
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # extract the training step counter from the file name
            counter = int(next(re.finditer("(\\d+)", ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
Appendix: the launch script
import os
import pprint
import numpy as np
import tensorflow as tf
from model import DCGAN

# receiving command-line arguments takes three steps: define the flags, read FLAGS, run main
flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
FLAGS = flags.FLAGS

# main must accept one argument, otherwise: 'TypeError: main() takes no arguments (1 given)';
# the parameter name itself can be anything
def main(_):
    # the pprint module prints data structures more readably
    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    run_config = tf.ConfigProto()
    # TensorFlow grabs GPU memory greedily by default; switch to allocating on demand
    run_config.gpu_options.allow_growth = True
    # alternatively, cap memory at a fixed fraction:
    # run_config.gpu_options.per_process_gpu_memory_fraction = 0.333

    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(
            sess,
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir)

        if FLAGS.train:
            dcgan.train(FLAGS)
        else:
            if not dcgan.load(FLAGS.checkpoint_dir)[0]:
                raise Exception("[!] Train a model first, then run test mode")

if __name__ == '__main__':
    tf.app.run()
The prediction part was not finished, so it is not included here, but that does not get in the way of understanding the overall approach.
One detail worth pointing out is dcgan.train(FLAGS): FLAGS is passed in directly and received inside train as config, so every flag can be read as config.<flag_name>, which is very convenient. It also shows the benefit of scripted TF programs (see 『TensorFlow』脚本化使用方法).
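A minimal standalone sketch of that pattern, assuming TF 1.x tf.app.flags; the Trainer class and the two flags here are invented for illustration:

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for Adam")
flags.DEFINE_integer("epoch", 25, "Number of training epochs")
FLAGS = flags.FLAGS

class Trainer:                      # illustrative stand-in for DCGAN
    def train(self, config):
        # any value defined as a flag is available as config.<name>
        print(config.learning_rate, config.epoch)

def main(_):
    Trainer().train(FLAGS)          # pass FLAGS straight through as the config object

if __name__ == '__main__':
    tf.app.run()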