如何使用 torch Dataloader 获取具有相同类的图片?

Posted

技术标签:

【中文标题】如何使用 torch Dataloader 获取具有相同类的图片?【英文标题】:How can I use torch Dataloader to get Picture having same class? 【发布时间】:2021-05-08 08:08:49 【问题描述】:

在我的数据集中,有 6 个类,每个类有 23 张图片 我用torchvision.dataset 制作了ImageFolder,效果很好。

dataset = vision_dataset.ImageFolder(root = DATA_ROOT,
                                     transform = vision_trans.Compose([
                                                    vision_trans.Resize(256),
                                                    vision_trans.CenterCrop(256),
                                                    vision_trans.ToTensor()
                                     ]))

dataloader = torch.utils.data.DataLoader(dataset = dataset, batch_size = SHOT_K,
                                         shuffle = False, num_workers = 2, )

但我想获得具有相同类别的批量图像。

...
tensor([2, 2, 2, 2, 2])
tensor([2, 2])
tensor([3, 3, 3, 3, 3])
...

这就是我想要的标签(批量数据的类)形式 但实际上 DataLoader 会这样工作

...
tensor([2, 2, 2, 2, 2])
tensor([2, 2, 3, 3, 3])
tensor([3, 3, 3, 3, 3])
...

如何获取每个标签的批次数据?

【问题讨论】:

【参考方案1】:

ImageFolder 无法方便地做到这一点。您应该为每个类创建一个数据集,并从您需要的数据集中加载批次。

更具体地说,假设您的文件夹结构是ImageFolder 所要求的,您需要创建一个小型数据集类:

class ImageSubFolder(torch.utils.data.Dataset):
    def __init__(self, root_dir, label):
        # Path toward the label-sorted subfolders of your dataset
        # Assuming images are named smthg like /path/to/label/xxxx.npy
        self._path = root_dir + label+ ":04d"

    def __len__(self):
        return count_files_in_directory(self._path)

    def __getitem__(self, index):
        return (np.load(self._path.format(index), label)

这只是为了展示类的逻辑,我相信你还有一些功能需要实现(你可以关注this tutorial)。 “其余要实现的功能留给读者作为练习”。无论如何,对于这个类,您只需要创建它的 6 个实例(每个类一个):

loaders = 
for label in ("dog", "cat", "plane", "tree", "mug", "car"):
    dataset = SubFolderDataset(DATA_ROOT, label)
    loaders[label] = torch.utils.data.DataLoader(dataset = dataset, batch_size = SHOT_K,shuffle = False, num_workers = 2, )

现在你有一个包含数据加载器的字典,它只加载给定类的样本。

【讨论】:

以上是关于如何使用 torch Dataloader 获取具有相同类的图片?的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch之torch.utils.data.DataLoader解读

PyTorch源码解读之torch.utils.data.DataLoader(转)

优化pytorch DataLoader提升数据加载速度

torch_12_dataset和dataLoader

torch.utils.data.DataLoader()详解Pytorch入门手册

如何使用 PyTorch 的 DataLoader 确保批次包含来自所有工作人员的样本?