[Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解

Posted 文火冰糖的硅基工坊

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解相关的知识,希望对你有一定的参考价值。

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:[Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - pix2pix test代码详解_文火冰糖(王文兵)的博客-CSDN博客


目录

第1章 pix2pix测试代码

1.1 代码路径

1.2 关键命令行参数

第2章 测试代码主要流程

第3章 代码详解


第1章 pix2pix测试代码

1.1 代码路径

.\\pytorch-CycleGAN-and-pix2pix\\test

1.2 关键命令行参数(以pix2pix为例)

-dataroot ./datasets/facades --direction BtoA --model pix2pix --name facades_pix2pix --verbose

其中 --verbose:表示打印网络架构

第2章 测试代码主要流程

(1)获取命令行参数:opt = TestOptions().parse() 

(2)设置test模式下命令命令行参数

(3)创建数据集:dataset = create_dataset(opt) 

(4)创建模型pix2pix模型: model = create_model(opt) 

(5)加载预预训练模型参数:model.setup(opt) 

预训练模型的位置有opt参数指定。

print(model.model_names)
print(model.visual_names)

[Network G] Total number of parameters : 54.414 M
-----------------------------------------------
['G']
['real_A', 'fake_B', 'real_B']

备注:

  • pix2pix的test,只有G网络,没有D网络。
  • 在pix2pix网络中,有三个关键点的图片:['real_A', 'fake_B', 'real_B'],其中real_A为输入图片,fake_B为创作生成的图片,real_B为与创作对应的真实图片,用它与创作图片进行比较和验证,以确定输出图片是否精准。

(6)构建web输出结构

(7)设置在评估模式:model.eval()

(8)读取数据集:for i, data in enumerate(dataset):

(9)unpack成对数据:model.set_input(data)

(10)根据输入图片,生成创作图片:model.test() , =》 调用foward() 进行前向预测

而前向预测只有G网络,没有D网络:

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)

(11)获取输出图片:visuals = model.get_current_visuals(),图片包括:real_A(输入图片), fake_B(创作图片), real_B (用于人工比较)这三部分。

(12)存储图片:save_images  =》 pytorch-CycleGAN-and-pix2pix\\results\\facades_pix2pix\\test_latest\\images\\xxx

(13)存储Website:webpage.save()   =》pytorch-CycleGAN-and-pix2pix\\results\\facades_pix2pix\\test_latest\\index.htlm

第3章 代码详解

   # 获取命令行参数
    opt = TestOptions().parse()  # get test options
    
    # hard-code some parameters for test
    opt.num_threads = 0   # test code only supports num_threads = 1
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1   # no visdom display; the test code saves the results to a html file.

    # 创建数据集
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    
    # 创建模型
    model = create_model(opt)      # create a model given opt.model and other options

    # 加载预训练模型
    model.setup(opt)               # regular setup: load and print networks; create schedulers

    # 构建web输出
    # create a website
    web_dir = os.path.join(opt.results_dir, opt.name, '_'.format(opt.phase, opt.epoch))  # define the website directory
    if opt.load_iter > 0:  # load_iter is 0 by default
        web_dir = ':s_iter:d'.format(web_dir, opt.load_iter)
    print('creating web directory', web_dir)
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))

    # test with eval mode. This only affects layers like batchnorm and dropout.
    # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode.
    # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.

    # 设置在评估模式
    if opt.eval:
        model.eval()

    # 读取数据集数据
    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
            
        # unpack成对数据
        model.set_input(data)  # unpack data from data loader
        
        # 生成图片
        model.test()           # run inference
        
        # 获取输出图片:real_A(输入图片), fake_B(创作图片), real_B (用于人工比较)
        visuals = model.get_current_visuals()  # get image results
        
        # 获取存放图片的路径
        img_path = model.get_image_paths()     # get image paths
        
        # 每运行5次,打印一次提示
        if i % 5 == 0:  # save images to an HTML file
            print('processing (%04d)-th image... %s' % (i, img_path))
            
        #存储图片
        save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
    
    # 图片信息存放到web中,用于可视化显示
    webpage.save()  # save the HTML

----------------- Options ---------------
             aspect_ratio: 1.0                           
               batch_size: 1                             
          checkpoints_dir: ./checkpoints                 
                crop_size: 256                           
                 dataroot: ./datasets/facades                [default: None]
             dataset_mode: aligned                       
                direction: BtoA                              [default: AtoB]
          display_winsize: 256                           
                    epoch: latest                        
                     eval: False                         
                  gpu_ids: 0                             
                init_gain: 0.02                          
                init_type: normal                        
                 input_nc: 3                             
                  isTrain: False                             [default: None]
                load_iter: 0                                 [default: 0]
                load_size: 256                           
         max_dataset_size: inf                           
                    model: pix2pix                           [default: test]
               n_layers_D: 3                             
                     name: facades_pix2pix                   [default: experiment_name]
                      ndf: 64                            
                     netD: basic                         
                     netG: unet_256                      
                      ngf: 64                            
               no_dropout: False                         
                  no_flip: False                         
                     norm: batch                         
                    ntest: inf                           
                 num_test: 50                            
              num_threads: 4                             
                output_nc: 3                             
                    phase: test                          
               preprocess: resize_and_crop               
              results_dir: ./results/                    
           serial_batches: False                         
                   suffix:                               
                  verbose: True                              [default: False]
----------------- End -------------------
dataset [AlignedDataset] was created
initialize network with normal
model [Pix2PixModel] was created
loading the model from ./checkpoints\\facades_pix2pix\\latest_net_G.pth
---------- Networks initialized -------------
DataParallel(
  (module): UnetGenerator(
    (model): UnetSkipConnectionBlock(
      (model): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): UnetSkipConnectionBlock(
          (model): Sequential(
            (0): LeakyReLU(negative_slope=0.2, inplace=True)
            (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
            (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (3): UnetSkipConnectionBlock(
              (model): Sequential(
                (0): LeakyReLU(negative_slope=0.2, inplace=True)
                (1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (3): UnetSkipConnectionBlock(
                  (model): Sequential(
                    (0): LeakyReLU(negative_slope=0.2, inplace=True)
                    (1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                    (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (3): UnetSkipConnectionBlock(
                      (model): Sequential(
                        (0): LeakyReLU(negative_slope=0.2, inplace=True)
                        (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                        (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                        (3): UnetSkipConnectionBlock(
                          (model): Sequential(
                            (0): LeakyReLU(negative_slope=0.2, inplace=True)
                            (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                            (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                            (3): UnetSkipConnectionBlock(
                              (model): Sequential(
                                (0): LeakyReLU(negative_slope=0.2, inplace=True)
                                (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                                (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                                (3): UnetSkipConnectionBlock(
                                  (model): Sequential(
                                    (0): LeakyReLU(negative_slope=0.2, inplace=True)
                                    (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                                    (2): ReLU(inplace=True)
                                    (3): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                                    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                                  )
                                )
                                (4): ReLU(inplace=True)
                                (5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                                (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                                (7): Dropout(p=0.5, inplace=False)
                              )
                            )
                            (4): ReLU(inplace=True)
                            (5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                            (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                            (7): Dropout(p=0.5, inplace=False)
                          )
                        )
                        (4): ReLU(inplace=True)
                        (5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                        (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                        (7): Dropout(p=0.5, inplace=False)
                      )
                    )
                    (4): ReLU(inplace=True)
                    (5): ConvTranspose2d(1024, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  )
                )
                (4): ReLU(inplace=True)
                (5): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
            )
            (4): ReLU(inplace=True)
            (5): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
            (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (2): ReLU(inplace=True)
        (3): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (4): Tanh()
      )
    )
  )
)
[Network G] Total number of parameters : 54.414 M
-----------------------------------------------
['G']
['real_A', 'fake_B', 'real_B']

creating web directory ./results/facades_pix2pix\\test_latest
C:\\ProgramData\\Anaconda3\\envs\\pytorch-gpu-os\\lib\\site-packages\\torchvision\\transforms\\transforms.py:287: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
  warnings.warn(
processing (0000)-th image... ['./datasets/facades\\\\test\\\\1.jpg']
processing (0005)-th image... ['./datasets/facades\\\\test\\\\103.jpg']
processing (0010)-th image... ['./datasets/facades\\\\test\\\\12.jpg']
processing (0015)-th image... ['./datasets/facades\\\\test\\\\17.jpg']
processing (0020)-th image... ['./datasets/facades\\\\test\\\\21.jpg']
processing (0025)-th image... ['./datasets/facades\\\\test\\\\26.jpg']
processing (0030)-th image... ['./datasets/facades\\\\test\\\\30.jpg']
processing (0035)-th image... ['./datasets/facades\\\\test\\\\35.jpg']
processing (0040)-th image... ['./datasets/facades\\\\test\\\\4.jpg']
processing (0045)-th image... ['./datasets/facades\\\\test\\\\44.jpg']


作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/122075510

以上是关于[Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解的主要内容,如果未能解决你的问题,请参考以下文章

[Pytorch系列-61]:生成对抗网络GAN - 基本原理 - 自动生成手写数字案例分析

[Pytorch系列-75]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - CycleGAN网络结构与代码实现详解

[Pytorch系列-63]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 代码总体架构与总体学习思路

[Pytorch系列-73]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - Train.py代码详解

[Pytorch系列-65]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 无监督图像生成CycleGan的基本原理

[Pytorch系列-68]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型测试CycleGAN模型