[Pytorch系列-71]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型训练pix2pix模型
Posted 文火冰糖的硅基工坊
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[Pytorch系列-71]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型训练pix2pix模型相关的知识,希望对你有一定的参考价值。
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/122077757
目录
第1步:下载或克隆pytorch-CycleGAN-and-pix2pix所有代码
第1章 概述
1.1 代码架构与总体思路
1.2 本章基本思路
(1)Pycharm进行调试,替代命令行或Jupter
(2)选择所需要硬盘空间小的数据进行测试
(3)熟悉pytorch-CycleGAN-and-pix2pix项目的使用
(4)熟悉pix2pix模型训练
1.3 训练方式
- 从头开始训练
- 从预预训练模型开始训练(官网提供的预训练模型只包括G网络,不包括D网络)
- 从上次训练结果开始训练
第2章 测试步骤
第1步:下载或克隆pytorch-CycleGAN-and-pix2pix所有代码
如果已经完成,可以跳过此步骤。
(1)Linux 命令行方式:!git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
(2)Windows浏览器下载:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
备注:
-
可以把代码下载或拷贝到jupter的工作目录中,以便后续可以通过jupter运行代码。
第2步:切换当前目录
(1)运行方式
- Windows 命令行方式:cd xxx
- jupter方式:
import os
os.chdir('pytorch-CycleGAN-and-pix2pix/')
- Pycharm: 把工程文件copy到Pycharm工作目录中即可
第3步:安装依赖文件(可视化工具)
如果已经完成,可以跳过此步骤。
- Windows 命令行方式
pip install -r requirements.txt
- Jupter方式
!pip install -r requirements.txt
torch>=0.4.1
torchvision>=0.2.1
dominate>=2.3.1
visdom>=0.1.8.3
第4步:下载pix2pix数据集
(1)下载方式
- Linux 命令行方式
bash ./datasets/download_pix2pix_dataset.sh facades
- Jupter方式
!bash ./datasets/download_pix2pix_dataset.sh facades
- Windows浏览器方式
根据./datasets/download_pix2pix_dataset.sh的内容,获取数据集URL, 通过URL手工下载:Index of /pix2pix/datasets
备注:
- 有些数据集很多,高达8G, 下载时需留意硬盘空间是否可以承载。
- pix2pix的数据集是成对出现的。
- Facades和cityscapes数据集最小,方便测试验证。
(2)数据集的存放路径
- 存放路径:pytorch-CycleGAN-and-pix2pix\\datasets
备注:必须同名,不能改名
(3)支持的数据集
支持的数据集:
- cityscapes: 城市轮廓转换成城市街景实体
- night2day: 晚上转换成白天
- edges2handbags:边沿转换成手提包
- edges2shoes:边沿转换成鞋子
- facades:房屋外观转换成房子实体(所需要的内存空间最小)
- maps:地图轮廓转换成实体地图
第5步:下载预训练模型
(1)下载方式
- Linux命令行方式
bash ./scripts/download_pix2pix_model.sh facades_label2photo
- jupter方式
!bash ./scripts/download_pix2pix_model.sh facades_label2photo
- Windows方式
根据download_pix2pix_model.sh脚步的内容,获取链接:
http://efrosgans.eecs.berkeley.edu/pix2pix/models-pytorch/
(2)存放路径
./checkpoints/xxx/latest_net_G.pth
xxx为模型名称。
备注:
- 需要把模型的名称,改为latest_net_G.pth,并存放在xxx目录中,这与使用预训练模型进行测试是不一样的。
- 官方的预训练模型,只有G网络的参数,没有D网络的参数,因此基本上需要重新训练。
第6步:启动可视化工具visdom
(1)启动visdom server
conda info -e
conda activate pytorch-gpu-os
python -m visdom.server
(2)启动visdom Client
http://localhost:8097
第7步:模型训练
(1)CPU方式(仅用于学习代码)运行
--dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
--gpu_ids -1 --niter_decay 1 --niter 1
--display_freq 1 --update_html_freq 1 --print_freq 1 --save_epoch_freq 5 --save_latest_freq 100
- --gpu_ids -1:表示使用CPU进行训练。
- -print_freq 1:每迭代多少次,在终端上打印一次提示信息, 默认100.
- -display_freq 1:每迭代多少次,在visdom客户端可视化一次图像,默认400
- --update_html_freq 1:每迭代多少次,更新一次html输出文件,默认1000.
- --save_epoch_freq 5: 每迭代多少次,存储一次模型参数
- --save_latest_freq 100:每迭代多少次,存储一次模型参数
- --niter 1:迭代的epoch次数, 默认100
- --niter_decay 1:迭代的epoch次数,对学习率进行一次衰减,默认100,总的epoc=niter + niter_decay + 1
备注:
- 该项目,采用GPU训练时,需要>8G的GPU内存,如果GPU条件不满足,在学习代码流程时,可以使用CPU进行训练
- 之所以修改这些默认参数,是因为CPU的训练太慢,不利于学习的效率。
(2)GPU方式(适用于正式训练模型)
--dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
备注:
在GPU的情况下,使用默认的参数。
(3)重头训练与基于先前的训练结果继续训练
--continue_train :如果设置,则基于先前的训练结果继续训练,如果不设置,则从头开始训练。
第8步:效果展示
(1)控制台打印显示
dataset [AlignedDataset] was created
The number of training images = 400
initialize network with normal
initialize network with normal
model [Pix2PixModel] was created
---------- Networks initialized -------------
[Network G] Total number of parameters : 54.414 M
[Network D] Total number of parameters : 2.769 M
-----------------------------------------------
(epoch: 1, iters: 1, time: 0.861, data: 8.576) G_GAN: 1.932 G_L1: 35.501 D_real: 0.595 D_fake: 1.196
(epoch: 1, iters: 2, time: 0.861, data: 0.001) G_GAN: 1.457 G_L1: 52.500 D_real: 1.197 D_fake: 1.248
(epoch: 1, iters: 3, time: 0.791, data: 0.000) G_GAN: 0.957 G_L1: 39.967 D_real: 1.115 D_fake: 1.119
............................
(epoch: 2, iters: 400, time: 0.973, data: 0.000) G_GAN: 1.775 G_L1: 39.182 D_real: 0.137 D_fake: 0.277
End of epoch 2 / 2 Time Taken: 408 sec
learning rate = 0.0000000
Process finished with exit code 0
(2)visdom图形化显示
- loss
- 训练结果
第9步:输出文件
(1)图片文件:
目录:
- pytorch-CycleGAN-and-pix2pix\\checkpoints\\facades_pix2pix\\web\\images
(2)模型文件
目录:
- pytorch-CycleGAN-and-pix2pix\\checkpoints\\facades_pix2pix\\
内容:
- latest_net_G.pth:最新的G网络模型文件, 再训练时,可以 基于此文件进行继续训练
- latest_net_D.pth:最新的D网络模型文件,再训练时,可以 基于此文件进行继续训练
- 200_net_G.pth:迭代n次的G网络模型文件
- 200_net_D.pth:迭代n次的D网络模型文件
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/122077757
以上是关于[Pytorch系列-71]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型训练pix2pix模型的主要内容,如果未能解决你的问题,请参考以下文章
[Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解
[Pytorch系列-61]:生成对抗网络GAN - 基本原理 - 自动生成手写数字案例分析
[Pytorch系列-75]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - CycleGAN网络结构与代码实现详解
[Pytorch系列-63]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 代码总体架构与总体学习思路
[Pytorch系列-73]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - Train.py代码详解
[Pytorch系列-65]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 无监督图像生成CycleGan的基本原理