cycleGAN源码解读:数据读取

Posted wzyuan

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了cycleGAN源码解读:数据读取相关的知识,希望对你有一定的参考价值。

源码地址:https://github.com/aitorzip/PyTorch-CycleGAN

数据的读取是比较简单的,cycleGAN对数据没有pair的需求,不同域的两个数据集分别存放于A,B两个文件夹,写好dataset接口即可

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC), 
                transforms.RandomCrop(opt.size), 
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True), 
                        batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)

上面的代码中,首先定义好buffer(后面细说),然后定义好图像变换,调用定义好的ImageDataset(继承自dataset) 对象,即可从dataloader中读取数据。下面是ImageDataset的代码

class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, unaligned=False, mode=train):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, %s/A % mode) + /*.*))
        self.files_B = sorted(glob.glob(os.path.join(root, %s/B % mode) + /*.*))

    def __getitem__(self, index):
        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))

        if self.unaligned:
            item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]))
        else:
            item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)]))

        return {A: item_A, B: item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

标准的实现了__init__, __getitem__, __len__三个接口,不过我还不太清楚这里对数据进行排序和对齐的目的,对齐可以按序读取,不对齐则随机读取最后,关于buffer,参考cycleGAN的论文,原话是这么说的“Second, to reduce model oscillation [15], we follow Shrivastava et al.’s strategy [46] and update the discriminators using a history of generated images rather than the ones produced by the latest generators. We keep an image buffer that stores the 50 previously created images

也就是说,是为了训练的稳定,采用历史生成的虚假样本来更新判别器,而不是当前生成的虚假样本,至于原理,参考的是另一篇论文。我们来看一下代码

class ReplayBuffer():
    def __init__(self, max_size=50):
        assert (max_size > 0), Empty buffer or trying to create a black hole. Be careful.
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0,1) > 0.5:
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))

定义了一个buffer对象,有一个数据存储表data,大小预设为50,我认为它的运转流程是这样的:数据表未填满时,每次读取的都是当前生成的虚假图像,当数据表填满时,随机决定 1. 在数据表中随机抽取一批数据,返回,并且用当前数据补充进来 2. 采用当前数据

至于为什么这样有道理,要看参考论文了

 

 

以上是关于cycleGAN源码解读:数据读取的主要内容,如果未能解决你的问题,请参考以下文章

[Pytorch系列-74]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - pix2pix网络结构与代码实现详解

论文解读 用于弱监督表面缺陷分割的缺陷注意模板循环对抗网络 (Defect attention template generation cycleGAN for weakly supervised)

使用CycleGAN训练自己制作的数据集,通俗教程,快速上手

StackExchange.Redis.Extensions.Core 源码解读之 Configuration用法

Flink读取Iceberg表的实现源码解读

Flink读取Iceberg表的实现源码解读