如何使用 torch Dataloader 获取具有相同类的图片?
Posted
技术标签:
【中文标题】如何使用 torch Dataloader 获取具有相同类的图片?【英文标题】:How can I use torch Dataloader to get Picture having same class? 【发布时间】:2021-05-08 08:08:49 【问题描述】:在我的数据集中,有 6 个类,每个类有 23 张图片
我用torchvision.dataset
制作了ImageFolder
,效果很好。
dataset = vision_dataset.ImageFolder(root = DATA_ROOT,
transform = vision_trans.Compose([
vision_trans.Resize(256),
vision_trans.CenterCrop(256),
vision_trans.ToTensor()
]))
dataloader = torch.utils.data.DataLoader(dataset = dataset, batch_size = SHOT_K,
shuffle = False, num_workers = 2, )
但我想获得具有相同类别的批量图像。
...
tensor([2, 2, 2, 2, 2])
tensor([2, 2])
tensor([3, 3, 3, 3, 3])
...
这就是我想要的标签(批量数据的类)形式 但实际上 DataLoader 会这样工作
...
tensor([2, 2, 2, 2, 2])
tensor([2, 2, 3, 3, 3])
tensor([3, 3, 3, 3, 3])
...
如何获取每个标签的批次数据?
【问题讨论】:
【参考方案1】:ImageFolder
无法方便地做到这一点。您应该为每个类创建一个数据集,并从您需要的数据集中加载批次。
更具体地说,假设您的文件夹结构是ImageFolder 所要求的,您需要创建一个小型数据集类:
class ImageSubFolder(torch.utils.data.Dataset):
def __init__(self, root_dir, label):
# Path toward the label-sorted subfolders of your dataset
# Assuming images are named smthg like /path/to/label/xxxx.npy
self._path = root_dir + label+ ":04d"
def __len__(self):
return count_files_in_directory(self._path)
def __getitem__(self, index):
return (np.load(self._path.format(index), label)
这只是为了展示类的逻辑,我相信你还有一些功能需要实现(你可以关注this tutorial)。 “其余要实现的功能留给读者作为练习”。无论如何,对于这个类,您只需要创建它的 6 个实例(每个类一个):
loaders =
for label in ("dog", "cat", "plane", "tree", "mug", "car"):
dataset = SubFolderDataset(DATA_ROOT, label)
loaders[label] = torch.utils.data.DataLoader(dataset = dataset, batch_size = SHOT_K,shuffle = False, num_workers = 2, )
现在你有一个包含数据加载器的字典,它只加载给定类的样本。
【讨论】:
以上是关于如何使用 torch Dataloader 获取具有相同类的图片?的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch之torch.utils.data.DataLoader解读
PyTorch源码解读之torch.utils.data.DataLoader(转)