[Pytorch系列-74]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - pix2pix网络结构与代码实现详解

Posted 文火冰糖的硅基工坊

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[Pytorch系列-74]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - pix2pix网络结构与代码实现详解相关的知识,希望对你有一定的参考价值。

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/122101231


目录

第1章 网络的定义

1.1 网络结构

1.2 代码来源

1.3 网络结构代码解读

1.4 输入数据集处理代码解读

1.5 前向运算

第2章 网络的训练

1.1 G生成网络的结构与代码解读

1.2 D判决网络的结构与代码解读

1.3 pix2pix网络整体的优化算法


第1章 网络的定义

1.1 网络结构

 相对于基础型的GAN网络,pix2pix网络,并没有增加新的网络结构,只在基础型的GAN基础上做了如下的优化:

  • 判决网络的输入:增加了输入图片,与输出fake图片一起参与判决
  • 判决网络的输出:不仅仅需要参与判决网络的判决,还需要与样本标签图片进行像素级的比较。

1.2 代码来源

pytorch-CycleGAN-and-pix2pix\\models\\pix2pix_model.py

1.3 网络结构代码解读

   def __init__(self, opt):
        """Initialize the pix2pix class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']
        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
  • train模式下需要定义G + D网络, 而在测试或预测模式下,只需要定义G网络。
  • 只有在训练模式下,才需要定义loss和优化算法。

1.4 输入数据集处理代码解读

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
  • pix2pix网络采用的是成对数据集(paired数据集),这个数据集是把两个图片按照字典的方式组合在一起的,因此需要先把他们分离开来。
  • real_A: 真实的输入图片。
  • real_B: 真实的标签图片 (标签不一定是分类的数值,也可以是一张图片)

1.5 前向运算

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)
  • 前向运算只使用G网络,创造或生成图片。
  • 前向运算的输入是真实图片Real_A,输出是生成图片fake_B, 与生成图片对应的是真实图片Real_B.
  • 如果仅仅是测试训练好的模型,有前向运算就可以了,不需要训练。

第2章 网络的训练

1.1 G生成网络的结构与代码解读

(1)G网络的训练架构

  • 在训练G网络时,需要锁定D网络。

 (2)G网络Loss代码实现

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""

        # First, G(A) should fake the discriminator
        # 组合real_A和fake_B
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        
        # 组合后图让锁定后的D网络进行判决
        pred_fake = self.netD(fake_AB)

        # 通过调整G网络,期望能够骗过D网络,即预测值接近True(1)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        
        # Second, G(A) = B
        # 确保生成图片fake_B, 不仅仅能够骗过D网络,还需要与标签图片real_B接近。
        # 反应在代码上,采用的像素点的绝对值差L1 loss来实现的。
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1

        # combine loss and calculate gradients
        # 最终的目标组合上述两个loss,优化算法使得组合后的loss最小。
        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        # 反向求导,求G网络所有参数的梯度
        self.loss_G.backward()

1.2 D判决网络的结构与代码解读

(1)D网络的训练架构

 (2)D网络Loss代码实现

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
        # we use conditional GANs; 
        # we need to feed both input and output to the discriminator
        # 合并输入图片real_A和生成图片fake_B
        fake_AB = torch.cat((self.real_A, self.fake_B), 1) 

        # 使用D网络进行判决
        pred_fake = self.netD(fake_AB.detach())
        
        # 通过调整D网络参数,需要识别出,该输出图片为"假"
        # 体现在代码上,就是预测结果pred_fake与False(0) 相比
        # 优化算法,使得判断结果为0
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        # 合并真实输入图片real_A与真实的标签图片real_B
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        
        # 使用D网络进行判决
        pred_real = self.netD(real_AB)
        
        # 通过调整D网络参数,需要识别出,该输出图片为"真"
        # 体现在代码上,就是预测结果pred_fake与True(1) 相比
        # 优化算法,使得判断结果为1
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # combine loss and calculate gradients
        # 组合上述两种loss,并求平均
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        # 反向求导,求G网络所有参数的梯度
        self.loss_D.backward()

1.3 pix2pix网络整体的优化算法

    def optimize_parameters(self):
        
        # 使用real_A进行前向运算,生出图片fake_B
        self.forward()                   # compute fake images: G(A)

        
        # update D
        # 使能D网络的梯度迭代
        self.set_requires_grad(self.netD, True)  # enable backprop for D

        # 复位D网络的梯度
        self.optimizer_D.zero_grad()     # set D's gradients to zero
        
        # 计算D网络的梯度
        self.backward_D()                # calculate gradients for D

        # 进行D网络的迭代迭代
        self.optimizer_D.step()          # update D's weights

        # update G
        # D requires no gradients when optimizing G
        # 需要手工锁定D网络
        self.set_requires_grad(self.netD, False) 
        
        # 复位G网梯度
        self.optimizer_G.zero_grad()        # set G's gradients to zero
        
        # 计算G网络的新的梯度
        self.backward_G()                   # calculate graidents for G
        
        # G网络迭代迭代
        self.optimizer_G.step()             # udpate G's weights
  • 先优化判决网络D,鉴别出真假
  • 然后优化生成网络G, 骗过优化后的D网络
  • 然后在优化判决网络D,鉴别出真假
  • 依次类推,不断对抗、优化、迭代、更新,直到D网络无法判决出G网络输出的真假,得到以假乱真的效果。

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/122101231

以上是关于[Pytorch系列-74]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - pix2pix网络结构与代码实现详解的主要内容,如果未能解决你的问题,请参考以下文章

[Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解

[Pytorch系列-61]:生成对抗网络GAN - 基本原理 - 自动生成手写数字案例分析

[Pytorch系列-75]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - CycleGAN网络结构与代码实现详解

[Pytorch系列-63]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 代码总体架构与总体学习思路

[Pytorch系列-73]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - Train.py代码详解

[Pytorch系列-65]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 无监督图像生成CycleGan的基本原理