[Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解
Posted 文火冰糖的硅基工坊
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解相关的知识,希望对你有一定的参考价值。
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
目录
第1章 pix2pix测试代码
1.1 代码路径
.\\pytorch-CycleGAN-and-pix2pix\\test
1.2 关键命令行参数(以pix2pix为例)
-dataroot ./datasets/facades --direction BtoA --model pix2pix --name facades_pix2pix --verbose
其中 --verbose:表示打印网络架构
第2章 测试代码主要流程
(1)获取命令行参数:opt = TestOptions().parse()
(2)设置test模式下命令命令行参数
(3)创建数据集:dataset = create_dataset(opt)
(4)创建模型pix2pix模型: model = create_model(opt)
(5)加载预预训练模型参数:model.setup(opt)
预训练模型的位置有opt参数指定。
print(model.model_names)
print(model.visual_names)
[Network G] Total number of parameters : 54.414 M
-----------------------------------------------
['G']
['real_A', 'fake_B', 'real_B']
备注:
- pix2pix的test,只有G网络,没有D网络。
- 在pix2pix网络中,有三个关键点的图片:['real_A', 'fake_B', 'real_B'],其中real_A为输入图片,fake_B为创作生成的图片,real_B为与创作对应的真实图片,用它与创作图片进行比较和验证,以确定输出图片是否精准。
(6)构建web输出结构
(7)设置在评估模式:model.eval()
(8)读取数据集:for i, data in enumerate(dataset):
(9)unpack成对数据:model.set_input(data)
(10)根据输入图片,生成创作图片:model.test() , =》 调用foward() 进行前向预测
而前向预测只有G网络,没有D网络:
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG(self.real_A) # G(A)
(11)获取输出图片:visuals = model.get_current_visuals(),图片包括:real_A(输入图片), fake_B(创作图片), real_B (用于人工比较)这三部分。
(12)存储图片:save_images =》 pytorch-CycleGAN-and-pix2pix\\results\\facades_pix2pix\\test_latest\\images\\xxx
(13)存储Website:webpage.save() =》pytorch-CycleGAN-and-pix2pix\\results\\facades_pix2pix\\test_latest\\index.htlm
第3章 代码详解
# 获取命令行参数
opt = TestOptions().parse() # get test options
# hard-code some parameters for test
opt.num_threads = 0 # test code only supports num_threads = 1
opt.batch_size = 1 # test code only supports batch_size = 1
opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
opt.display_id = -1 # no visdom display; the test code saves the results to a html file.
# 创建数据集
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
# 创建模型
model = create_model(opt) # create a model given opt.model and other options
# 加载预训练模型
model.setup(opt) # regular setup: load and print networks; create schedulers
# 构建web输出
# create a website
web_dir = os.path.join(opt.results_dir, opt.name, '_'.format(opt.phase, opt.epoch)) # define the website directory
if opt.load_iter > 0: # load_iter is 0 by default
web_dir = ':s_iter:d'.format(web_dir, opt.load_iter)
print('creating web directory', web_dir)
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
# test with eval mode. This only affects layers like batchnorm and dropout.
# For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode.
# For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
# 设置在评估模式
if opt.eval:
model.eval()
# 读取数据集数据
for i, data in enumerate(dataset):
if i >= opt.num_test: # only apply our model to opt.num_test images.
break
# unpack成对数据
model.set_input(data) # unpack data from data loader
# 生成图片
model.test() # run inference
# 获取输出图片:real_A(输入图片), fake_B(创作图片), real_B (用于人工比较)
visuals = model.get_current_visuals() # get image results
# 获取存放图片的路径
img_path = model.get_image_paths() # get image paths
# 每运行5次,打印一次提示
if i % 5 == 0: # save images to an HTML file
print('processing (%04d)-th image... %s' % (i, img_path))
#存储图片
save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
# 图片信息存放到web中,用于可视化显示
webpage.save() # save the HTML
----------------- Options ---------------
aspect_ratio: 1.0
batch_size: 1
checkpoints_dir: ./checkpoints
crop_size: 256
dataroot: ./datasets/facades [default: None]
dataset_mode: aligned
direction: BtoA [default: AtoB]
display_winsize: 256
epoch: latest
eval: False
gpu_ids: 0
init_gain: 0.02
init_type: normal
input_nc: 3
isTrain: False [default: None]
load_iter: 0 [default: 0]
load_size: 256
max_dataset_size: inf
model: pix2pix [default: test]
n_layers_D: 3
name: facades_pix2pix [default: experiment_name]
ndf: 64
netD: basic
netG: unet_256
ngf: 64
no_dropout: False
no_flip: False
norm: batch
ntest: inf
num_test: 50
num_threads: 4
output_nc: 3
phase: test
preprocess: resize_and_crop
results_dir: ./results/
serial_batches: False
suffix:
verbose: True [default: False]
----------------- End -------------------
dataset [AlignedDataset] was created
initialize network with normal
model [Pix2PixModel] was created
loading the model from ./checkpoints\\facades_pix2pix\\latest_net_G.pth
---------- Networks initialized -------------
DataParallel(
(module): UnetGenerator(
(model): UnetSkipConnectionBlock(
(model): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): UnetSkipConnectionBlock(
(model): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): Dropout(p=0.5, inplace=False)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): Dropout(p=0.5, inplace=False)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): Dropout(p=0.5, inplace=False)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(1024, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(4): ReLU(inplace=True)
(5): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(4): Tanh()
)
)
)
)
[Network G] Total number of parameters : 54.414 M
-----------------------------------------------
['G']
['real_A', 'fake_B', 'real_B']
creating web directory ./results/facades_pix2pix\\test_latest
C:\\ProgramData\\Anaconda3\\envs\\pytorch-gpu-os\\lib\\site-packages\\torchvision\\transforms\\transforms.py:287: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
warnings.warn(
processing (0000)-th image... ['./datasets/facades\\\\test\\\\1.jpg']
processing (0005)-th image... ['./datasets/facades\\\\test\\\\103.jpg']
processing (0010)-th image... ['./datasets/facades\\\\test\\\\12.jpg']
processing (0015)-th image... ['./datasets/facades\\\\test\\\\17.jpg']
processing (0020)-th image... ['./datasets/facades\\\\test\\\\21.jpg']
processing (0025)-th image... ['./datasets/facades\\\\test\\\\26.jpg']
processing (0030)-th image... ['./datasets/facades\\\\test\\\\30.jpg']
processing (0035)-th image... ['./datasets/facades\\\\test\\\\35.jpg']
processing (0040)-th image... ['./datasets/facades\\\\test\\\\4.jpg']
processing (0045)-th image... ['./datasets/facades\\\\test\\\\44.jpg']
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/122075510
以上是关于[Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解的主要内容,如果未能解决你的问题,请参考以下文章
[Pytorch系列-61]:生成对抗网络GAN - 基本原理 - 自动生成手写数字案例分析
[Pytorch系列-75]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - CycleGAN网络结构与代码实现详解
[Pytorch系列-63]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 代码总体架构与总体学习思路
[Pytorch系列-73]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - Train.py代码详解
[Pytorch系列-65]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 无监督图像生成CycleGan的基本原理
[Pytorch系列-68]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型测试CycleGAN模型