从零开始实现一个简单的CycleGAN项目

Posted 江户川柯壮

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了从零开始实现一个简单的CycleGAN项目相关的知识,希望对你有一定的参考价值。

项目地址:https://github.com/jzsherlock4869/cyclegan-pytorch

pytorch 中CycleGAN(循环一致生成对抗网络)的简单且易于修改的实现

CycleGAN 的基本说明(来自原始论文):

使用“horse2zebra”数据集重新实现此 repo 的结果(没有参数调整,仍然有一些明显的伪影,您可以调整超参数以使其更好~)

安装和准备
下载 CycleGAN 的常用数据集并使用它们来训练和验证代码管道

Monet-Photo 传输:Kaggle Monet-Photo 传输
Horse-Zebra 转移:Kaggle Horse-Zebra 转移
然后按以下结构准备数据集文件夹:

├── monet_dataset
│ ├── monet_jpg
│ └── photo_jpg
└── zebra_dataset
├── testA
├── testB
├── trainA
└── trainB
Git 克隆这个 repo 并 cd 到根文件夹

git clone https://github.com/jzsherlock4869/cyclegan-pytorch
cd cyclegan-python

根据requirements.txt文件夹中的安装必要的python包

开始训练
将 dataroot 修改config/000_base_cyclegan_horse2zebra.yml为您自己的数据集路径,然后运行训练过程:

python train_cyclegan.py --opt configs/000_base_cyclegan_horse2zebra.yml

“ horse2zebra ”和“ photo2monet ”的数据集类已经在这个 repo 中实现。

get_photo2monet_train_dataloader要在您自己的数据集(域 A 和域 B)上进行训练,请将您自己的数据加载器编写 get_horse2zebra_train_dataloader为 data/sample_dataloader.py

def get_your_custom_train_dataloader(root_dir="your_path", 
                                    batch_size=8, 
                                    img_size=(256, 256)):
    imgA_sub, imgB_sub = "subdirnameA", "subdirnameB" # sub directory name to your root_dir
    postfix_set=["jpg"]  # which postfix is your images
    train_dataset = CycleGANDataset(root_dir, imgA_sub, imgB_sub, postfix_set, img_size)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    return train_dataloader

然后修改第train_cyclegan.py54-60 行以添加您的数据集(记得先导入它们!)

if which_dataset == 'horse2zebra':
    train_dataloader = get_horse2zebra_train_dataloader(dataroot, 
                                                        batch_size=batch_size, 
                                                        img_size=img_size)
elif which_dataset == 'photo2monet':
    train_dataloader = get_photo2monet_train_dataloader(dataroot, 
                                                        batch_size=batch_size, 
                                                        img_size=img_size)
   #  add lines here
 elif which_dataset == 'your_custom_dataset':
    train_dataloader = get_your_custom_train_dataloader(dataroot, 
                                                        batch_size=batch_size, 
                                                        img_size=img_size)
#add lines here
else:
    raise NotImplementedError(f"Unrecognized dataset type : which_dataset")

参考
这段代码是对 CycleGAN 的重新实现,更易于理解和修改,尤其适合初学者。原论文是:

@inproceedingsCycleGAN2017, title=Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networkss, author=Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A, booktitle=Computer Vision (ICCV), 2017 IEEE International Conference on, year=2017

一个keras版本和教程,详细解释了CycleGAN的理论和过程:

https://machinelearningmastery.com/cyclegan-tutorial-with-keras/

而且,这个代码库的代码架构和风格也参考了BasicSR和UnpairedSR,部分功能直接借鉴。欣赏他们的好作品~

欢迎star⭐如果这个 repo 对你有帮助:)

以上是关于从零开始实现一个简单的CycleGAN项目的主要内容,如果未能解决你的问题,请参考以下文章

从零开始手写Tomcat的教程---未完待续

Android项目实战 | 从零开始写app实现服务端智慧服务页面数据的解析

从零开始,编写简单的课程信息管理系统(使用jsp+servlet+javabean架构)

从零开始实现数据结构 动态数组

从零开始写 OS 内核 - 运行 shell

从零开始写 OS 内核 - 键盘驱动