SinGAN一张照片即可生成同样的照片(附简化版代码)

Posted AI信仰者

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了SinGAN一张照片即可生成同样的照片(附简化版代码)相关的知识,希望对你有一定的参考价值。

1、摘要

本文主要讲解:SinGAN-一张照片即可生成一模一样的照片(附简化版代码)
主要思路:

  1. 先由一个Z_N输入到G_N的生成器得到生成图像(这一步是单纯由噪声生成,其他生成器的输入都是由随机噪声图像z_n和上一层生成的 上采样到当前生成器尺寸组成)。
  2. 接着利用生成图像的图像块(每一层图像块的大小不一样,按照由粗糙到精细、由大到小)和当前层的图像块(由训练数据下采样得到)放入判别器中进行判断,直到两者不能被判别器区分。
  3. 通过这种一层一层、由下往上的训练过程,得到最终的结果。

2、相关技术

SinGAN架构
一种基于层级的patch-GAN模型(Markovian discriminator)。如下图所示,模型的每个部分负责输入图像的不同尺度捕获图像块分布。这种层级GAN模型感受野小和有限的功能,可以防止网络记住整图的信息。虽然类似的网络结构被应用过,但这是首次应用在一张图像的内部学习上。

模型是由金字塔形式大小的生成器 组成,训练数据 也是金字塔形式大小组成,训练数据是由一个 因子控制,一些r>0。根据每层 的图像块分布,相应层的生成器 产生真实的图像实例。然后通过对抗学习,判别器 通过对生成器 产生的图像块(生成图像的某一部分)进行判别,达到相对较好的状态(以目前来说达不到最终的纳什均衡点),最后完成训练过程。

从刚刚的图中我们可以看到,每个尺度注入噪声后,先由粗糙的尺度开始生成图像,然后按照相应的顺序传递到相对应的生成器,最终生成精细的尺度;某一层的所有生成器和判别器有着相同的感受野,随着由下往上的生成过程,因此可以捕获尺度减小的结构信息。

3、完整代码和步骤

算法训练的效果如此视频:

SinGAN训练过程

主运行程序入口

import os
from SinGAN.run_train import functions
from SinGAN.run_train.manipulate import SinGAN_generate
from SinGAN.run_train.training import train
from SinGAN.run_train.config import get_arguments

if __name__ == '__main__':
    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='../Input/Images')
    parser.add_argument('--input_name', help='input image name', default='food.jpg')
    parser.add_argument('--mode', help='task to be done', default='train')
    opt = parser.parse_args()
    #
    opt = functions.post_config(opt)
    Gs = []
    Zs = []
    reals = []
    NoiseAmp = []
    dir2save = functions.generate_dir2save(opt)

    if (os.path.exists(dir2save)):
        print('trained model already exist')
    else:
        try:
            os.makedirs(dir2save)
        except OSError:
            pass
        # 将图片读取成torch版的数据
        real = functions.read_image(opt)
        # 将图片适配尺寸
        functions.adjust_scales2image(real, opt)
        # 开始训练模型 opt 手动输入的参数
        train(opt, Gs, Zs, reals, NoiseAmp)
        # 根据模型生成图片  生成具有任意大小和比例的新图像
        SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)

training.py

	import os
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import math
import matplotlib.pyplot as plt

from SinGAN.run_train import functions, models
from SinGAN.run_train.imresize import imresize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    # 计算局部权重 调整大小
    real = imresize(real_, opt.scale1, opt)
    # 创造真实图片的锥体
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0
    # 全卷积GANs组成的金字塔
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if (nfc_prev == opt.nfc):
            # 加载训练好的模型
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        # 是否固定部分参数进行网络训练
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return


def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    #     对Tensor使用0进行边界填充
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=device)
    # 返回一个大小为fill_value的张量
    z_opt = torch.full(fixed_noise.shape, 0, device=device)
    z_opt = m_noise(z_opt)
    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    # 按需调整学习率
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []
    # 它是从噪声生成图像的
    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        # Dsteps 'Discriminator inner steps',default=3
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(device)
            # D_real_map = output.detach()
            errD_real = -output.mean()  # -a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: 最大化 D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            D_fake_map = output.detach()
            errG = -output.mean()
            # errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # plt.imsave('%s/D_fake.png'   % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png'   % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png'    % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/prev.png'     %  (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            # plt.imsave('%s/noise.png'    %  (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            # plt.imsave('%s/z_prev.png'   % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG


def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    G_z = in_s
    if len(Gs) > 0:
        if mode == 'rand'以上是关于SinGAN一张照片即可生成同样的照片(附简化版代码)的主要内容,如果未能解决你的问题,请参考以下文章

一张照片,AI生成抽象画(CLIPasso项目安装使用) | 机器学习系列

一条链接获取你的照片附源码

NAS使用心得使用Synology Photos管理照片

一条链接获取你的照片附源码

一张照片,AI生成抽象画(CLIPasso项目安装使用) | 机器学习

姐姐带我玩转java设计模式(内附照片)- 代理模式