Building CycleGAN with Keras

Posted by Paul-Huang


Preface: This article was compiled by the editors of cha138.com. It covers building CycleGAN with Keras, and we hope you find it useful.

1. Principles

Reference: CycleGAN Principles

2. Data Preparation

2.1 Downloading the Data

  1. Horse-to-zebra dataset (horse2zebra):
    https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip
  2. Apple-to-orange dataset (apple2orange):
    https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/apple2orange.zip
  3. Monet-painting-to-photo dataset (monet2photo):
    https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/monet2photo.zip
  4. Maps dataset (maps):
    https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/maps.zip
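The data loader in section 3.3 globs for folders such as ./datasets/horse2zebra/trainA, so each archive needs to be unpacked under ./datasets. A minimal standard-library sketch of that step follows; the helper names dataset_url and download_dataset are hypothetical, and it assumes each zip contains a top-level folder named after the dataset (e.g. horse2zebra/trainA), which is how these archives are commonly laid out.

```python
import os
import urllib.request
import zipfile

# Base URL shared by all the download links listed above
BASE_URL = "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/"


def dataset_url(name):
    """Return the zip URL for a dataset name such as 'horse2zebra'."""
    return BASE_URL + name + ".zip"


def download_dataset(name, root="./datasets"):
    """Download <name>.zip into root (if not already present) and extract it there.

    Assumes the archive's top-level folder is <name>/, giving root/<name>/trainA etc.
    """
    os.makedirs(root, exist_ok=True)
    zip_path = os.path.join(root, name + ".zip")
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(dataset_url(name), zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(root)
    return os.path.join(root, name)
```

Calling download_dataset("horse2zebra") once before training should leave the folders where DataLoader(dataset_name='horse2zebra') expects them; an already downloaded zip is reused rather than fetched again.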

2.2 Configuring set_session

  • set_session no longer exists in tf.keras.backend under TensorFlow 2.x, so use the TF1-compatible version at the top of the file:
    import tensorflow as tf
    sess = tf.compat.v1.Session()
    tf.compat.v1.keras.backend.set_session(sess)
    
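  • In the initial setup, also configure GPU memory. This must run before any GPU has been initialized. Memory growth and a fixed memory limit are alternative strategies, so enable only one of them:
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    if gpus:  # avoid an IndexError on machines without a GPU
        # option 1: allocate GPU memory on demand
        tf.config.experimental.set_memory_growth(gpus[0], True)
        # option 2 (instead of option 1): cap usage at a fixed limit in MB, set according to your GPU's memory
        # tf.config.experimental.set_virtual_device_configuration(gpus[0],
        #     [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=6000)])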
    

2.3 Running TensorFlow/Keras on the CPU

2.3.1 Global configuration

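OOM (Out of Memory) errors are common when running TensorFlow code, usually because batch_size is set so large that GPU memory runs out. To run the code on the CPU only, add the following to the original code: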

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

Note: this code must be placed before import tensorflow or import keras; otherwise it has no effect.
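Because the ordering is what makes this work, one option is to wrap the two assignments in a small helper that warns if TensorFlow has already been imported. force_cpu is a hypothetical name, not part of any library, and the sys.modules check only detects an earlier import; it cannot tell whether CUDA was actually initialized yet.

```python
import os
import sys
import warnings


def force_cpu():
    """Hide all CUDA GPUs by setting the environment variables shown above.

    Call this before importing tensorflow/keras. If tensorflow is already
    imported, the variables are still set but may be ignored, so warn.
    """
    if "tensorflow" in sys.modules:
        warnings.warn("tensorflow is already imported; CUDA_VISIBLE_DEVICES may have no effect")
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


force_cpu()
# import tensorflow as tf  # import TensorFlow only after force_cpu() has run
```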

2.3.2 TensorFlow configuration

Use TensorFlow's own configuration to place the computation on the CPU:

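with tf.device('/cpu:0'):
    ...  # build and run the model here

Or, with a TF1-style session:

config = tf.compat.v1.ConfigProto(device_count={'GPU': 0, 'CPU': 4})  # use no GPU devices and at most 4 CPU devices
with tf.compat.v1.Session(config=config) as sess:
    ...  # run the graph here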

2.4 Installing the keras_contrib library on Windows

Reference: "GAN Seems Pretty Fun 7: Image Style Transfer with CycleGAN"
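keras_contrib is needed here for its InstanceNormalization layer, which both networks below use. As a rough illustration of what that layer computes with axis=3 on NHWC tensors, the NumPy sketch below normalizes each sample and each channel over its height and width. It omits the learned scale and shift (gamma/beta), and eps=1e-3 is an assumption meant to mirror keras_contrib's default epsilon.

```python
import numpy as np


def instance_norm(x, eps=1e-3):
    """Normalize an NHWC batch per sample and per channel over the H and W axes.

    Learned gamma/beta are omitted; eps is assumed to match keras_contrib's default.
    """
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
batch = rng.normal(loc=5.0, scale=3.0, size=(2, 8, 8, 3))
out = instance_norm(batch)
# every (sample, channel) slice of out now has mean close to 0 and std close to 1
```

Unlike batch normalization, the statistics never mix different images in a batch, which is why it is a common choice for style-transfer generators trained with batch_size=1.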

3. Building the Networks

3.1 Generator

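  • The generator's goal is to take an input image and convert it into the corresponding image in the desired style.
  • The generator consists of three parts: an encoder, a transformer and a decoder. (A U-Net can also be used instead.)
  • Create a build_generator.py file: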
# author: HQR
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras import layers
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization


def residual_block(input_layer, kernel_size, filter_num, block):
    # residual block: two zero-padded convolutions plus a skip connection
    con_name_base = 'res' + block + '_branch'
    in_name_base = 'in' + block + '_branch'
    # first conv layer
    x1 = ZeroPadding2D(padding=(1, 1))(input_layer)
    x1 = Conv2D(filters=filter_num, kernel_size=kernel_size, name=con_name_base + '2a')(x1)
    x1 = InstanceNormalization(axis=3, name=in_name_base + '2a')(x1)
    x1 = Activation('relu')(x1)
    # second conv layer
    x2 = ZeroPadding2D(padding=(1, 1))(x1)
    x2 = Conv2D(filters=filter_num, kernel_size=kernel_size, name=con_name_base + '2c')(x2)
    x2 = InstanceNormalization(axis=3, name=in_name_base + '2c')(x2)
    # residual (skip) connection
    x = layers.add([x2, input_layer])
    x = Activation('relu')(x)
    return x


def encoded(layer_input, filters, pad_size=(1, 1), kernel_size=(3, 3), strides=1, upsampling2d=False):
    if upsampling2d:
        layer_input = UpSampling2D((2, 2))(layer_input)

    x = ZeroPadding2D(padding=pad_size)(layer_input)
    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)
    return x


def build_generator(input_height, input_width, channel):

    img_input = Input(shape=(input_height, input_width, channel))
    # step 1: encoder
    # 128,128,3 ->  128,128,64
    g1 = encoded(img_input, filters=64, pad_size=(3, 3), kernel_size=(7, 7), strides=1)
    # 128,128,64 -> 64,64,128
    g1 = encoded(g1, filters=128, pad_size=(1, 1), kernel_size=(3, 3), strides=2)
    # 64,64,128 -> 32,32,256
    g1 = encoded(g1, filters=256, pad_size=(1, 1), kernel_size=(3, 3), strides=2)

    # step 2: transformer, 9 residual blocks
    for i in range(9):
        g1 = residual_block(g1, kernel_size=(3, 3), filter_num=256, block=str(i))

    # step 3: decoder
    # 32,32,256 -> 64,64,128
    g3 = encoded(g1, filters=128, pad_size=(1, 1), kernel_size=(3, 3), strides=1, upsampling2d=True)
    # 64,64,128 -> 128,128,64
    g3 = encoded(g3, filters=64, pad_size=(1, 1), kernel_size=(3, 3), strides=1, upsampling2d=True)
    # 128,128,64 -> 128,128,3
    g3 = ZeroPadding2D(padding=(3, 3))(g3)
    img_output = Conv2D(channel, kernel_size=(7, 7), activation='tanh')(g3)

    return Model(img_input, img_output)
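The shape comments in build_generator follow from the output-size formula for a 'valid' convolution after explicit zero padding, out = floor((in + 2*pad - kernel) / stride) + 1, with each UpSampling2D doubling the size first. The pure-Python sketch below (hypothetical helpers conv_out and generator_trace) replays the layer settings above, which is handy for checking the shapes after changing img_res.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Spatial size after ZeroPadding2D(pad) followed by a 'valid' Conv2D."""
    return (size + 2 * pad - kernel) // stride + 1


def generator_trace(size=128):
    """Spatial size after each stage of build_generator, for square inputs."""
    trace = [size]
    size = conv_out(size, kernel=7, stride=1, pad=3)      # encoder, 64 filters
    trace.append(size)
    for _ in range(2):                                    # encoder, 128 then 256 filters
        size = conv_out(size, kernel=3, stride=2, pad=1)
        trace.append(size)
    for _ in range(9):                                    # 9 residual blocks, 2 convs each
        for _ in range(2):
            size = conv_out(size, kernel=3, stride=1, pad=1)
    trace.append(size)
    for _ in range(2):                                    # decoder: UpSampling2D(2) then conv
        size = conv_out(size * 2, kernel=3, stride=1, pad=1)
        trace.append(size)
    size = conv_out(size, kernel=7, stride=1, pad=3)      # output conv with tanh
    trace.append(size)
    return trace


print(generator_trace(128))  # [128, 128, 64, 32, 32, 64, 128, 128]
```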

3.2 Discriminator

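  • The discriminator takes an image as input and tries to predict whether it is a real image or an image produced by the generator.
  • The discriminator is itself a convolutional network: it first extracts features from the image and then decides whether those features belong to the target class, using a final convolution layer that produces a single-channel output.
  • The discriminator is trained with the mean squared error loss proposed in LSGAN; compared with the standard cross-entropy GAN loss, this tends to improve the quality of the generated (fake) images.
  • For a 128×128 input the final convolution outputs shape (8, 8, 1). This is the PatchGAN idea: each output element judges one patch of the input rather than the whole image.
  • Create a build_discriminator.py file: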
# author: HQR
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization


def build_discriminator(input_height, input_width, channel):

    def conv2d(layer_input, filters, f_size=4, normalization=True):
        # a stride-2 'same' convolution halves the spatial size
        d = Conv2D(filters=filters, kernel_size=f_size, strides=2, padding="same")(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if normalization:
            d = InstanceNormalization()(d)
        return d

    img_input = Input(shape=(input_height, input_width, channel))
    # 128,128,3 -> 64,64,64
    d1 = conv2d(img_input, 64, normalization=False)
    # 64,64,64 -> 32,32,128
    d2 = conv2d(d1, 128)
    # 32,32,128 -> 16,16,256
    d3 = conv2d(d2, 256)
    # 16,16,256 -> 8,8,512
    d4 = conv2d(d3, 512)
    # judge real/fake at every output position, i.e. for every patch of the input
    # 8,8,512 -> 8,8,1
    validity = Conv2D(filters=1, kernel_size=3, strides=1, padding="same")(d4)

    return Model(img_input, validity)
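A NumPy sketch of what the 'mse' (LSGAN) loss means for this PatchGAN output follows. It assumes the training loop, which is not part of this excerpt, uses all-ones targets for real images and all-zeros targets for generated images, each shaped (batch,) + (8, 8, 1), and averages the real and fake losses with a factor of 0.5; the helper names are hypothetical.

```python
import numpy as np


def patch_size(img_size, n_stride2_convs=4):
    """Side length after n stride-2 'same' convolutions (ceil(size / 2) each time)."""
    for _ in range(n_stride2_convs):
        img_size = -(-img_size // 2)  # ceil division
    return img_size


def lsgan_d_loss(d_real, d_fake):
    """Discriminator loss: push outputs on real patches to 1 and on fake patches to 0."""
    real_loss = np.mean((d_real - 1.0) ** 2)
    fake_loss = np.mean(d_fake ** 2)
    return 0.5 * (real_loss + fake_loss)


def lsgan_g_loss(d_fake):
    """Generator loss: make the discriminator output 1 on fake patches."""
    return np.mean((d_fake - 1.0) ** 2)


batch_size = 2
p = patch_size(128)                      # 8, i.e. 128 / 2**4
valid = np.ones((batch_size, p, p, 1))   # targets for real images
fake = np.zeros((batch_size, p, p, 1))   # targets for generated images
```

Because the targets are whole 8×8 maps, every patch contributes its own squared error, and a discriminator that outputs 0.5 everywhere gets a loss of 0.25.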

3.3 Data Loading

  • load_batch is consumed batch by batch during training, so it uses yield (making it a Python generator) instead of return.
  • Create a data_loader.py file:
# author:HQR
import imageio
from skimage.transform import resize
from glob import glob
import numpy as np


class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, domain, batch_size=1, is_testing=False):
        data_type = "train%s" % domain if not is_testing else "test%s" % domain
        path = glob('./datasets/%s/%s/*' % (self.dataset_name, data_type))

        batch_image = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_image:
            img = self.imread(img_path)
            if not is_testing:
                img = resize(img, self.img_res)

                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = resize(img, self.img_res)
            imgs.append(img)

        imgs = np.array(imgs)/127.5 - 1.
        return imgs

    def load_batch(self, batch_size=1, is_testing=False):
        # the datasets ship trainA/trainB and testA/testB folders
        data_type = "train" if not is_testing else "test"
        path_A = glob('./datasets/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('./datasets/%s/%sB/*' % (self.dataset_name, data_type))
        # number of batches, limited by the smaller domain
        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size
        # randomly sample the same number of (unpaired) images from each domain
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)
        # yield the data one batch at a time
        for i in range(self.n_batches):
            batch_A = path_A[i*batch_size: (i+1)*batch_size]
            batch_B = path_B[i*batch_size: (i+1)*batch_size]
            imgs_A, imgs_B = [], []
            # zip pairs up the two batches element by element
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)

                img_A = resize(img_A, self.img_res)
                img_B = resize(img_B, self.img_res)

                # random horizontal flip as augmentation (training only)
                if not is_testing and np.random.random() > 0.5:
                    img_A = np.fliplr(img_A)
                    img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            # scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
            imgs_A = np.array(imgs_A)/127.5 - 1.
            imgs_B = np.array(imgs_B)/127.5 - 1.

            yield imgs_A, imgs_B

    def load_img(self, path):
        img = self.imread(path)
        img = resize(img, self.img_res)
        img = img/127.5 - 1.
        return img[np.newaxis, :, :, :]

    def imread(self, path):
        # np.float was removed in NumPy 1.24; use the explicit float64 type
        return imageio.imread(path, pilmode='RGB').astype(np.float64)
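The two ideas behind load_batch, lazy batching with yield and scaling pixels into the tanh range, can be tried without any image files. This is a NumPy-only sketch with hypothetical names (batch_pairs, to_model_range, to_pixel_range); batch_pairs mirrors load_batch's logic on arbitrary items instead of file paths, and to_pixel_range is an assumed inverse for displaying generator output with matplotlib.

```python
import numpy as np


def to_model_range(img):
    """Map pixel values from [0, 255] to [-1, 1], the range of the generator's tanh output."""
    return img / 127.5 - 1.0


def to_pixel_range(img):
    """Map model output from [-1, 1] back to [0, 1] for display."""
    return 0.5 * img + 0.5


def batch_pairs(items_a, items_b, batch_size, seed=0):
    """Shuffle two unpaired lists and lazily yield aligned batches.

    As in load_batch, the number of batches is limited by the shorter list.
    """
    rng = np.random.default_rng(seed)
    n_batches = min(len(items_a), len(items_b)) // batch_size
    total = n_batches * batch_size
    a = rng.choice(items_a, total, replace=False)
    b = rng.choice(items_b, total, replace=False)
    for i in range(n_batches):
        # nothing below runs until the caller asks for the next batch
        yield a[i * batch_size:(i + 1) * batch_size], b[i * batch_size:(i + 1) * batch_size]


for batch_a, batch_b in batch_pairs(list(range(10)), list(range(7)), batch_size=3):
    print(batch_a, batch_b)  # two batches of 3, since 7 // 3 == 2
```

A generator keeps only one batch of images in memory at a time, which matters more on the CPU-only setup described in section 2.3.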

3.4 Training

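  1. Initialization
    1. Create two generators: one converts style-A images into style B, the other converts style-B images into style A.
    2. Create two discriminators, one judging whether style-A images are real or fake and one doing the same for style-B images.
    3. The discriminators are trained with the same loss as in LSGAN (mean squared error between their outputs and the real/fake targets).
  2. Loss setup
    There are six loss terms (see CycleGAN Principles for details):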
from __future__ import print_function, division
import os
# GPU memory kept running out, so the GPU is hidden and training runs on the CPU.
# These variables must be set before TensorFlow is imported (see 2.3.1).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import datetime
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from build_generator import *
from build_discriminator import *
from data_loader import *

# set_session (TF1-compatible session, see 2.2)
sess = tf.compat.v1.Session()
tf.compat.v1.keras.backend.set_session(sess)

class CycleGAN():
    def __init__(self):
        # GPU memory settings to avoid OOM (disabled here because training runs on the CPU)
        # config = tf.compat.v1.ConfigProto()
        # config.gpu_options.allocator_type = 'BFC'  # A "Best-fit with coalescing" algorithm, simplified from a version of dlmalloc.
        # config.gpu_options.per_process_gpu_memory_fraction = 0.8
        # config.gpu_options.allow_growth = True
        # tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))
        # input images are 128*128*3
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # load the data
        self.dataset_name = 'horse2zebra'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))
        # Calculate output shape of D (PatchGAN):
        # the discriminator's 4 stride-2 convs shrink 128 to 128 / 2**4 = 8
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # hyperparameters
        # Loss weights
        self.lambda_cycle = 10.0  # Cycle-consistency loss
        self.lambda_id = 0.1 * self.lambda_cycle  # Identity loss
        # optimizer: Adam(learning_rate=0.0002, beta_1=0.5); each model gets its own
        # instance because newer Keras optimizers can fail when shared between models
        def make_optimizer():
            return Adam(0.0002, 0.5)

        # -------------------------#
        #   build the discriminators
        # -------------------------#
        self.d_A = build_discriminator(self.img_rows, self.img_cols, self.channels)
        self.d_B = build_discriminator(self.img_rows, self.img_cols, self.channels)
        self.d_A.compile(loss='mse', optimizer=make_optimizer(), metrics=['accuracy'])
        self.d_B.compile(loss='mse', optimizer=make_optimizer(), metrics=['accuracy'])
        self.d_A.summary()

        # -------------------------#
        #   build the generators
        # -------------------------#
        self.g_A2B = build_generator(self.img_rows, self.img_cols, self.channels)
        self.g_B2A = build_generator(self.img_rows, self.img_cols, self.channels)
        self.g_A2B.summary()

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # translate to the other domain (fake images)
        fake_B = self.g_A2B(img_A)
        fake_A = self.g_B2A(img_B)
        # translate back (reconstruction images)
        recon_A = self.g_B2A(fake_B)
        recon_B = self.g_A2B(fake_A)
        # identity images: a generator fed an image already in its target domain should leave it unchanged
        id_A = self.g_B2A(img_A)
        id_B = self.g_A2B(img_B)

        # -------------------------#
        #   combine generators and discriminators; the discriminators are
        #   frozen while the generators are trained through this model
        # -------------------------#
        self.d_A.trainable = False
        self.d_B.trainable = False
        # the discriminators judge whether the translated images look real
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)
        # combined model: 2 adversarial, 2 cycle-consistency and 2 identity outputs
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[valid_A, valid_B,
                                       recon_A, recon_B,
                                       id_A, id_B])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                              loss_weights=[1, 1,
                                            self.lambda_cycle, self.lambda_cycle,
                                            self.lambda_id, self.lambda_id],
                              optimizer=make_optimizer())
        # (the rest of the class, including the training loop, is not part of this excerpt)
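Keras turns the six outputs of the combined model into one scalar by multiplying each output's loss by its loss weight and summing. A NumPy sketch of that total, assuming the weights [1, 1, lambda_cycle, lambda_cycle, lambda_id, lambda_id] with the values from __init__ (10.0 and 1.0); the function names are hypothetical, and for equally shaped outputs the element-wise mean used here equals Keras's per-sample-then-batch average.

```python
import numpy as np


def mse(pred, target):
    return np.mean((pred - target) ** 2)


def mae(pred, target):
    return np.mean(np.abs(pred - target))


def generator_total_loss(outputs, img_a, img_b, valid, lambda_cycle=10.0, lambda_id=1.0):
    """Weighted sum of the six losses, in the order of the combined model's outputs."""
    valid_a, valid_b, recon_a, recon_b, id_a, id_b = outputs
    adversarial = mse(valid_a, valid) + mse(valid_b, valid)      # fool both discriminators
    cycle = mae(recon_a, img_a) + mae(recon_b, img_b)           # A -> B -> A returns to A
    identity = mae(id_a, img_a) + mae(id_b, img_b)              # same-domain input unchanged
    return adversarial + lambda_cycle * cycle + lambda_id * identity


img_a = np.zeros((1, 4, 4, 3))
img_b = np.ones((1, 4, 4, 3))
valid = np.ones((1, 2, 2, 1))
# an average reconstruction error of 0.1 on A alone costs 10 * 0.1 = 1.0
print(generator_total_loss((valid, valid, img_a + 0.1, img_b, img_a, img_b), img_a, img_b, valid))
```

With lambda_cycle ten times the adversarial weight, a small cycle error already outweighs a perfect adversarial score, which is what keeps the translated image tied to its input.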
