『TensorFlow』 A Class-Based Paradigm for Neural Networks: GAN as an Example

1. Import packages:

import os
import time
import math
from glob import glob
from PIL import Image
import tensorflow as tf
import numpy as np

import ops                    # wrappers around the layer functions (conv2d, deconv2d, linear, lrelu, batch_norm)
import utils                  # other helper functions (get_image, save_images, ...)

2. A simple ad-hoc helper function:

def conv_out_size_same(size, stride):
    # round size / stride up to the nearest integer (the smallest integer >= size / stride)
    return int(math.ceil(float(size) / float(stride)))
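As a quick sanity check (assuming the default output size of 64 used later in this post), the helper yields the chain of intermediate feature-map sizes that the generator walks back up through:

# reuses the conv_out_size_same defined above
sizes = [64]
for _ in range(4):
    sizes.append(conv_out_size_same(sizes[-1], 2))
print(sizes)   # [64, 32, 16, 8, 4]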

3. Declaring & initializing the class:

Class attributes are not used in this example, but in practice they usually are as well.

Class attributes & __init__ initialization: __init__ receives the constructor arguments and turns them into the low-level attribute values; data loading, or at least the list of data file names, usually also lives in __init__.
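Before the full DCGAN class, here is a minimal sketch of that pattern (the names are hypothetical and not part of the example): a class attribute holds a constant shared by every instance, while __init__ holds the per-run configuration and the data file list.

import os
from glob import glob

class ToyModel(object):
    # class attribute: a default shared by all instances
    DEFAULT_PATTERN = "*.jpg"

    def __init__(self, data_dir, batch_size=64):
        # instance attributes: per-run configuration ...
        self.batch_size = batch_size
        # ... and the list of data file names, collected here in __init__
        self.file_list = sorted(glob(os.path.join(data_dir, self.DEFAULT_PATTERN)))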

class DCGAN():

    def __init__(self, sess,
                 input_height=108, input_width=108,
                 crop=True, batch_size=64, sample_num=64,
                 output_height=64, output_width=64,
                 z_dim=100, gf_dim=64,
                 df_dim=64, gfc_dim=1024,
                 dfc_dim=1024, c_dim=3,
                 dataset_name='default', input_fname_pattern='*.jpg',
                 checkpoint_dir=None, sample_dir=None):
        """
        Args:
            sess: TensorFlow session
            batch_size: The size of batch. Should be specified before training.
            z_dim: (optional) Dimension of dim for Z. [100]
            gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
            df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
            gfc_dim: (optional) Dimension of gen units for fully connected layer. [1024]
            dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
            c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
        """
        self.sess = sess
        self.batch_size = batch_size
        self.sample_num = sample_num

        # crop and the input/output sizes
        # if crop is True, the output size is the size fed into the network (inputs are cropped to it)
        # if crop is False, the input goes straight into the network's input layer
        self.crop = crop
        self.input_height = input_height
        self.input_width = input_width
        self.output_height = output_height
        self.output_width = output_width

        self.z_dim = z_dim

        self.gf_dim = gf_dim
        self.df_dim = df_dim

        self.dfc_dim = dfc_dim
        self.gfc_dim = gfc_dim

        self.g_bn0 = ops.batch_norm(name='g_bn0')
        self.g_bn1 = ops.batch_norm(name='g_bn1')
        self.g_bn2 = ops.batch_norm(name='g_bn2')
        self.g_bn3 = ops.batch_norm(name='g_bn3')

        self.d_bn1 = ops.batch_norm(name='d_bn1')
        self.d_bn2 = ops.batch_norm(name='d_bn2')
        self.d_bn3 = ops.batch_norm(name='d_bn3')

        # load the data
        self.dataset_name = dataset_name
        self.input_fname_pattern = input_fname_pattern
        self.checkpoint_dir = checkpoint_dir

        self.data = glob(os.path.join("./data", self.dataset_name, self.input_fname_pattern))  # collect all image paths

        # read one image to determine the number of channels
        imreadImg = np.asarray(Image.open(self.data[0]))
        if len(imreadImg.shape) >= 3:
            self.c_dim = imreadImg.shape[-1]
        else:
            self.c_dim = 1

        self.grayscale = (self.c_dim == 1)

4. Building the network:

Because of how a GAN is structured, this stage is split up: build_model(self) is the backbone, while discriminator(self, image, reuse=False) and generator(self, z) act as modules. Together they cover the whole flow from the data entering the network to the computation of the loss functions.

    def build_model(self):

        if self.crop:
            image_dims = [self.output_height, self.output_width, self.c_dim]
        else:
            image_dims = [self.input_height, self.input_width, self.c_dim]

        # data input layer
        self.input_layer = tf.placeholder(tf.float32, [self.batch_size] + image_dims, name='input_layer')
        inputs = self.input_layer

        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')
        self.z_sum = tf.summary.histogram('z', self.z)

        # main computation nodes
        # forward passes
        self.G                  = self.generator(self.z)
        self.D, self.D_logits   = self.discriminator(inputs, reuse=False)
        self.sampler            = self.sampler(self.z)
        self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)

        # summaries
        self.G_sum = tf.summary.image('G', self.G)
        self.D_sum = tf.summary.histogram('D', self.D)
        self.D__sum = tf.summary.histogram('D_', self.D_)

        # loss functions
        # build
        self.d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_)))
        self.d_loss = self.d_loss_real + self.d_loss_fake

        # summaries
        self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)

        # split the trainable variables between D and G
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        # saver object
        self.saver = tf.train.Saver()

    def discriminator(self,image,reuse=False):
        with tf.variable_scope('discriminator', reuse=reuse):
            h0 = ops.lrelu(ops.conv2d(image, self.df_dim, name='d_h0_conv'))
            h1 = ops.lrelu(self.d_bn1(ops.conv2d(h0, self.df_dim * 2, name='d_h1_conv')))
            h2 = ops.lrelu(self.d_bn2(ops.conv2d(h1, self.df_dim * 4, name='d_h2_conv')))
            h3 = ops.lrelu(self.d_bn3(ops.conv2d(h2, self.df_dim * 8, name='d_h3_conv')))
            h4 = ops.linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin')

        return tf.nn.sigmoid(h4),h4

    def generator(self,z):
        with tf.variable_scope('generator'):
            s_h, s_w = self.output_height, self.output_width                        # target image size
            s_h2,s_w2 = conv_out_size_same(s_h,2),conv_out_size_same(s_w,2)
            s_h4,s_w4 = conv_out_size_same(s_h2,2),conv_out_size_same(s_w2,2)
            s_h8,s_w8 = conv_out_size_same(s_h4,2),conv_out_size_same(s_w4,2)
            s_h16,s_w16 = conv_out_size_same(s_h8,2),conv_out_size_same(s_w8,2)

            # batch_size stays fixed; h and w double at every layer while the channel count halves

            # linear projection layer
            self.z_, self.h0_w, self.h0_b = ops.linear(z, self.gf_dim * 8 * s_h16 * s_w16, 'g_h0_lin', with_w=True)
            self.h0 = tf.reshape(self.z_,[-1,s_h16,s_w16,self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(self.h0))

            # transposed-convolution (deconv) layers
            self.h1, self.h1_w, self.h1_b = ops.deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim * 4], name='g_h1', with_w=True)
            h1 = tf.nn.relu(self.g_bn1(self.h1))

            h2, self.h2_w, self.h2_b = ops.deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2], name='g_h2', with_w=True)
            h2 = tf.nn.relu(self.g_bn2(h2))

            h3, self.h3_w, self.h3_b = ops.deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim * 1], name='g_h3', with_w=True)
            h3 = tf.nn.relu(self.g_bn3(h3))

            h4, self.h4_w, self.h4_b = ops.deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)

        return tf.nn.tanh(h4)

5. The prediction part:

This is the part where an ordinary network would predict labels; in a GAN it corresponds to generating the fake images. It does not take part in training.

    def sampler(self,z):
        # exactly the same structure as the generator, sharing its variables; the only difference is that batch norm runs with train=False (is_training False), which changes how its two moving-average statistics are used
        with tf.variable_scope("generator") as scope:
            scope.reuse_variables()

            s_h,s_w = self.output_height,self.output_width
            s_h2,s_w2 = conv_out_size_same(s_h,2),conv_out_size_same(s_w,2)
            s_h4,s_w4 = conv_out_size_same(s_h2,2),conv_out_size_same(s_w2,2)
            s_h8,s_w8 = conv_out_size_same(s_h4,2),conv_out_size_same(s_w4,2)
            s_h16,s_w16 = conv_out_size_same(s_h8,2),conv_out_size_same(s_w8,2)

            h0 = tf.reshape(ops.linear(z, self.gf_dim * 8 * s_h16 * s_w16, 'g_h0_lin'), [-1, s_h16, s_w16, self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(h0,train=False))

            h1 = ops.deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim * 4], name='g_h1')
            h1 = tf.nn.relu(self.g_bn1(h1,train=False))

            h2 = ops.deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2], name='g_h2')
            h2 = tf.nn.relu(self.g_bn2(h2,train=False))

            h3 = ops.deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim * 1], name='g_h3')
            h3 = tf.nn.relu(self.g_bn3(h3,train=False))

            h4 = ops.deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4')

        return tf.nn.tanh(h4)

 

6. The training part:

The most tedious part of the class:

  • build the optimizers
  • load the results of the previous training run
  • iterate over epochs and batches:
    • read batch_size images
    • feed them into the network and train
    • print intermediate quantities for monitoring
    • save the model

    def train(self,config):
        # discriminator optimizer (on the combined d_loss)
        d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        # generator optimizer
        g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)

        tf.global_variables_initializer().run()

        # merged summaries tracking how each value evolves over the iterations
        self.g_sum = tf.summary.merge([self.z_sum, self.D__sum, self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
        self.d_sum = tf.summary.merge([self.z_sum, self.D_sum, self.d_loss_real_sum, self.d_loss_sum])

        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

        # read sample_num images to use as a fixed sample batch
        sample_files = self.data[0:self.sample_num]
        sample = [utils.get_image(sample_file,
                      input_height=self.input_height,
                      input_width=self.input_width,
                      resize_height=self.output_height,
                      resize_width=self.output_width,
                      crop=self.crop) for sample_file in sample_files]
        sample_inputs = np.array(sample).astype(np.float32)
        sample_z = np.random.uniform(-1,1,size=(self.sample_num,self.z_dim))

        counter = 1
        start_time = time.time()
        could_load,checkpoint_counter = self.load(self.checkpoint_dir)

        # if a saved model exists, load it and continue training
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for epoch in range(config.epoch):
            self.data = glob(os.path.join(
                "./data",config.dataset,self.input_fname_pattern))
            batch_idxs = min(len(self.data),config.train_size) // config.batch_size
            for idx in range(0,batch_idxs):

                # read a batch of real images x
                batch_files = self.data[idx * config.batch_size:(idx + 1) * config.batch_size]
                batch = [
                    utils.get_image(batch_file,
                              input_height=self.input_height,
                              input_width=self.input_width,
                              resize_height=self.output_height,
                              resize_width=self.output_width,
                              crop=self.crop) for batch_file in batch_files]
                batch_images = np.array(batch).astype(np.float32)

                # sample the noise z
                batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \
                    .astype(np.float32)

                # Update D network
                _,summary_str = self.sess.run([d_optim,self.d_sum],
                                              feed_dict={self.input_layer: batch_images,self.z: batch_z})
                self.writer.add_summary(summary_str,counter)

                # Update G network
                _,summary_str = self.sess.run([g_optim,self.g_sum],
                                              feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)                # what the writer records here are the summary values themselves, not a log in the usual sense

                # Update G network
                # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
                _,summary_str = self.sess.run([g_optim,self.g_sum],
                                              feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str,counter)

                # evaluate the loss values
                errD_fake = self.d_loss_fake.eval({self.z: batch_z})
                errD_real = self.d_loss_real.eval({self.input_layer: batch_images})
                errG = self.g_loss.eval({self.z: batch_z})

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"                       % (epoch,idx,batch_idxs,
                         time.time() - start_time,errD_fake + errD_real,errG))
                if np.mod(counter,100) == 1:
                    try:
                        samples,d_loss,g_loss = self.sess.run(
                            [self.sampler,self.d_loss,self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.input_layer: sample_inputs,
                            },
                        )
                        utils.save_images(samples,utils.image_manifold_size(samples.shape[0]),
                                    './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss,g_loss))
                    except:
                        print("one pic error!...")
                if np.mod(counter,500) == 2:
                    self.save(config.checkpoint_dir,counter)

 

A demo of saving & loading the model

Personally I find it a bit over-engineered, but it is still well worth borrowing from.

For example, hiding a method behind a @property decorator feels unnecessary to me, since it is only ever called internally anyway...

The stock idiom it uses for checking whether a directory exists, on the other hand, is a good one:

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
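As a side note (an alternative idiom, not what this code uses): on Python 3.2+ the check-and-create pair collapses into a single call.

import os

checkpoint_dir = "./checkpoint"             # hypothetical path, for illustration only
os.makedirs(checkpoint_dir, exist_ok=True)  # does nothing if the directory already exists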

 

The author put real effort into organizing file names so that different datasets can be run, which makes the load part fairly involved, so I added a few extra comments there.

    # model saving & loading

    # checkpoint_dir/datasetname_batchsize_outputheight_outputwidth/<model files>
    @property
    def model_dir(self):
        return "{}_{}_{}_{}".format(
            self.dataset_name,self.batch_size,
            self.output_height,self.output_width)

    def save(self,checkpoint_dir,step):
        model_name = "DCGAN.model"
        checkpoint_dir = os.path.join(checkpoint_dir,self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir,model_name),
                        global_step=step)

    def load(self,checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)                 # join the checkpoint root with the per-dataset subdirectory
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)                          # checkpoint directory -> latest checkpoint file name
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)                  # strip the path from the checkpoint name; arguably unnecessary, since the stored name already has no path
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))    # restore the parameters
            counter = int(next(re.finditer("(\\d+)", ckpt_name)).group(0))             # extract the number of training steps
            print(" [*] Success to read {}".format(ckpt_name))
            return True,counter
        else:
            print(" [*] Failed to find a checkpoint")
        return False,0
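For intuition about the counter-extraction line, here is a tiny standalone check; the file name is hypothetical, but it matches the basename pattern that saver.save(..., global_step=step) produces.

import re

ckpt_name = "DCGAN.model-1500"   # hypothetical basename of the latest checkpoint
counter = int(next(re.finditer("(\\d+)", ckpt_name)).group(0))
print(counter)                   # 1500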

 

Appendix: the driver script

import os
import pprint
import numpy as np
import tensorflow as tf

from model import DCGAN


# Receiving command-line arguments takes three steps: get the flags module, DEFINE each flag, then read them back through flags.FLAGS

flags = tf.app.flags

flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")

FLAGS = flags.FLAGS


# main must take an argument, otherwise: 'TypeError: main() takes no arguments (1 given)';
# the parameter name can be anything; there is no constraint on it
def main(_):
    # the pprint module: prints data structures more readably
    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)


    run_config = tf.ConfigProto()
    # TensorFlow's default way of grabbing GPU memory is extremely greedy; switch it to allocate on demand
    run_config.gpu_options.allow_growth = True
    # the line below would instead reserve a fixed fraction
    # run_config.gpu_options.per_process_gpu_memory_fraction=0.333

    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(
            sess,
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir)

        if FLAGS.train:
            dcgan.train(FLAGS)
        else:
            if not dcgan.load(FLAGS.checkpoint_dir)[0]:
                raise Exception("[!] Train a model first, then run test mode")

if __name__ == '__main__':
    tf.app.run()

 

The prediction/inference part was not written up properly, so I have not included it here, but that does not get in the way of understanding the overall approach.

One detail worth pointing out is dcgan.train(FLAGS): the FLAGS object is passed in whole, the train function receives it as config, and parameters are then read as config.<name>. This is very convenient, and it also helps in understanding why script-style TF programs are handy (see the earlier post 『TensorFlow』脚本化使用方法).
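A minimal sketch of why this works (SimpleNamespace here is just a hypothetical stand-in for the FLAGS object; anything that supports attribute access will do):

from types import SimpleNamespace

# a fake "FLAGS": train(config) only needs attribute access such as config.learning_rate
config = SimpleNamespace(learning_rate=0.0002, beta1=0.5, epoch=25, batch_size=64)

def train(config):
    print("lr=%g, beta1=%g, epochs=%d" % (config.learning_rate, config.beta1, config.epoch))

train(config)   # prints: lr=0.0002, beta1=0.5, epochs=25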

 
