图像超分辨率重构实战

Posted 有理想的打工人

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了图像超分辨率重构实战相关的知识,希望对你有一定的参考价值。

低分辨率图像重建


今天我们来介绍利用对抗生成网络(GAN)对低分辨率图像进行重构的介绍。再开始今天的任务之前,给大家强调一下,我们需要使用1.x.x版本的tensorflow和tensorlayer,我是用的是3.6版本的python,3.4.1.15版本的opencv以及1.8.0版本的tensorflow和tensorlayer。另外还有其他的一些模块需要安装,直接按照错误提示安装即可。

任务总览

分辨率在图片中的直接反应就是图像的大小,分辨率越高,图像的初始大小越大。如果将不同分辨率的图像放缩到同样的大小,分辨率低的图像会更模糊。超分辨率重构就是将分辨率低的图片重构成清晰的高分辨率图像:

所需要用到的网络结构图为:

数据加载与配置

这个部分对应着生成网络和判别网络的input部分的初始化。
首先需要大家下载srgan任务,打开config文件,我们主要的参数都将在这个文件中进行修改:

from easydict import EasyDict as edict
import json
config = edict()
config.TRAIN = edict()

## Adam
# batch设置过大有可能会引发内存不足的报错
config.TRAIN.batch_size = 4 # 可以适当调整
config.TRAIN.lr_init = 1e-4
config.TRAIN.beta1 = 0.9
## 初始化生成器
config.TRAIN.n_epoch_init = 100
## 判别器学习 (SRGAN)
config.TRAIN.n_epoch = 2000
config.TRAIN.lr_decay = 0.1
config.TRAIN.decay_every = int(config.TRAIN.n_epoch / 2)

# 训练集路径指定
config.TRAIN.hr_img_path = 'E:\\srgan\\srdata\\srdata\\DIV2K_train_HR'
config.TRAIN.lr_img_path = 'E:\\srgan\\srdata\\srdata\\DIV2K_train_LR_bicubic\\X4'

config.VALID = edict()
# 测试集路径制定
config.VALID.hr_img_path = 'E:\\srgan\\srdata\\srdata\\DIV2K_valid_HR'
config.VALID.lr_img_path = 'E:\\srgan\\srdata\\srdata\\DIV2K_valid_LR_bicubic\\X4'

def log_config(filename, cfg):
    with open(filename, 'w') as f:
        f.write("================================================\\n")
        f.write(json.dumps(cfg, indent=4))
        f.write("\\n================================================\\n")

这里改好之后,我们需要对一些main.py文件里的函数进行一些设置,比如传递进batch_size,学习率,epoch等,同时要指定好生成的图像以及模型等文件的存储位置,之后把再图像读取进来:

import os
import time
import pickle, random
import numpy as np
import logging, scipy

import tensorflow as tf
import tensorlayer as tl
from model import SRGAN_g, SRGAN_d, Vgg19_simple_api
from utils import *
from config import config, log_config

## Adam
batch_size = config.TRAIN.batch_size # 4
lr_init = config.TRAIN.lr_init # 1e-4
beta1 = config.TRAIN.beta1 # 0.9
## 初始化生成器
n_epoch_init = config.TRAIN.n_epoch_init # 100
## 判别器学习(SRGAN)
n_epoch = config.TRAIN.n_epoch # 2000
lr_decay = config.TRAIN.lr_decay # 0.1
decay_every = config.TRAIN.decay_every # 1000

ni = int(np.sqrt(batch_size))


def train():
    ## 创建文件夹保存结果图像和训练模型
    save_dir_ginit = "samples/_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    # load_file_list可以把所有的文件都加载进来
    # path指定文件夹的路径
    #  regx='.*.png'代表读取所有.png的文件
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))[:800] # 如果出现memory error可以这样操作减少一次读取的数据量
    train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))       # 不加切片读取也是可以的,但一定要注意传入的低分高分图像数量要匹配
                                                                                                                             # 读取全部的内容花费时间较长
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    ## 如果计算机内存够大,可以加在全部内容

    # n_threads可以当成多线程,这里意思是每8张一组一并处理
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=8)

## 设置生成器、判别器和特征提取模块的输入内容
    # 制作生成器和判别器的输入数据
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    # 判别器接收的原始高分辨图像
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')
    # vgg特征提取模块初始化设置
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, # 剪切成对应的大小
        align_corners=False)  
    t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False)

这样一来,我们就完成了数据的加载和小部分参数的配置。接下来我们就需要在main.py文件中继续调整生成模块、判别模块、特征提取、损失函数设置和测试模块。

模型设置

以上我们已经完成了读取文件夹内的图像内容的任务,接下来就需要用生成器和判别器分别处理各自的输入内容了。源码中生成器和判别器的具体操作是在model.py文件中执行的,main.py只是负责调用这个模块。因此我们先讲解model中的内容。首先说生成器:

生成器所需要用到的卷积和残差模块,以及对应结果加和处理都需要在这里进行设置:

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
import time
import os

# 生成网络
def SRGAN_g(t_image, is_train=False, reuse=False):
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    g_init = tf.random_normal_initializer(1., 0.02)
    with tf.variable_scope("SRGAN_g", reuse=reuse) as vs:
        # 输入层,内容+名字
        n = InputLayer(t_image, name='in')
        # 进行卷积(初始化)
        n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
        temp = n
        
        # 设置16个残差模块
        for i in range(16):
            nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c1/%s' % i)
            nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
            nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i)
            nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
            nn = ElementwiseLayer([n, nn], tf.add, name='b_residual_add/%s' % i)
            n = nn

        # 残差信息整合
        # 对应网络示意图中的skip connection步骤
        n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s1/b/m')
        # 把最开始的结果(temp)加到当前的结果当中
        n = ElementwiseLayer([n, temp], tf.add, name='add3')

        # 重构出图
        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1')

        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2')

        n = Conv2d(n, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
        return n

对于判别器,也要在model中进行设置:

def SRGAN_d(input_images, is_train=True, reuse=False): # reuse指定为True意味着输入的图像是从原始数据集中取到的,
                                                       #           False意味着图像是生成器生成的
    # 参数的初始化指定
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    gamma_init = tf.random_normal_initializer(1., 0.02)
    df_dim = 64
    lrelu = lambda x: tl.act.lrelu(x, 0.2)
    # 基础的判别网络
    with tf.variable_scope("SRGAN_d", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        net_in = InputLayer(input_images, name='input/images')
        net_h0 = Conv2d(net_in, df_dim, (4, 4), (2, 2), act=lrelu, padding='SAME', W_init=w_init, name='h0/c')

        net_h1 = Conv2d(net_h0, df_dim * 2, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h1/c')
        net_h1 = BatchNormLayer(net_h1, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h1/bn')
        net_h2 = Conv2d(net_h1, df_dim * 4, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h2/c')
        net_h2 = BatchNormLayer(net_h2, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h2/bn')
        net_h3 = Conv2d(net_h2, df_dim * 8, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h3/c')
        net_h3 = BatchNormLayer(net_h3, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h3/bn')
        net_h4 = Conv2d(net_h3, df_dim * 16, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h4/c')
        net_h4 = BatchNormLayer(net_h4, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h4/bn')
        net_h5 = Conv2d(net_h4, df_dim * 32, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h5/c')
        net_h5 = BatchNormLayer(net_h5, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h5/bn')
        net_h6 = Conv2d(net_h5, df_dim * 16, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h6/c')
        net_h6 = BatchNormLayer(net_h6, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h6/bn')
        net_h7 = Conv2d(net_h6, df_dim * 8, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h7/c')
        net_h7 = BatchNormLayer(net_h7, is_train=is_train, gamma_init=gamma_init, name='h7/bn')

        net = Conv2d(net_h7, df_dim * 2, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='res/bn')
        net = Conv2d(net, df_dim * 2, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c2')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='res/bn2')
        net = Conv2d(net, df_dim * 8, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c3')
        net = BatchNormLayer(net, is_train=is_train, gamma_init=gamma_init, name='res/bn3')
        net_h8 = ElementwiseLayer([net_h7, net], combine_fn=tf.add, name='res/add')
        net_h8.outputs = tl.act.lrelu(net_h8.outputs, 0.2)

        net_ho = FlattenLayer(net_h8, name='ho/flatten') # 池化
        net_ho = DenseLayer(net_ho, n_units=1, act=tf.identity, W_init=w_init, name='ho/dense')
        logits = net_ho.outputs
        net_ho.outputs = tf.nn.sigmoid(net_ho.outputs)

    return net_ho, logits

如果上述内容中有不懂的参数,可以查询文档
还有,我们需要把特征提取模块(VGG)加进来,这个模块具体的作用会在损失函数里具体介绍,我们这里只需要知道vgg会帮我们提取生成图像和原始高清图像做特征比对,我们把它也写到model里:

def Vgg19_simple_api(rgb, reuse):
    # 减均值
    VGG_MEAN = [103.939, 116.779, 123.68]
    with tf.variable_scope("VGG19", reuse=reuse) as vs:
        start_time = time.time()
        print(以上是关于图像超分辨率重构实战的主要内容,如果未能解决你的问题,请参考以下文章

深度原理与框架-图像超分辨重构-tensorlayer

深度学习工程师必看:更简单的超分辨重构方法拿走不谢

SRCNN 图像超分辨率重建(tf2)

[转]图像超分辨率重建简介

一文教你使用图像超分辨率模型(针对小白,不训练)

一文教你使用图像超分辨率模型(针对小白,不训练)