pytorch 数据集中每个类的实例数

Posted

技术标签:

【中文标题】pytorch 数据集中每个类的实例数【英文标题】:Number of instances per class in pytorch dataset 【发布时间】:2020-09-30 19:01:11 【问题描述】:

我正在尝试使用 PyTorch 制作一个简单的图像分类器。 这就是我将数据加载到数据集和 dataLoader 中的方式:

batch_size = 64
validation_split = 0.2
data_dir = PROJECT_PATH+"/categorized_products"
transform = transforms.Compose([transforms.Grayscale(), CustomToTensor()])

dataset = ImageFolder(data_dir, transform=transform)

indices = list(range(len(dataset)))

train_indices = indices[:int(len(indices)*0.8)] 
test_indices = indices[int(len(indices)*0.8):]

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, num_workers=16)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=test_sampler, num_workers=16)

我想分别打印出训练和测试数据中每个类的图像数量,如下所示:

在火车数据中:

鞋子:20 衬衫:14

在测试数据中:

鞋:4 衬衫:3

我试过了:

from collections import Counter
print(dict(Counter(sample_tup[1] for sample_tup in dataset.imgs)))

但我收到了这个错误:

AttributeError: 'MyDataset' object has no attribute 'img'

【问题讨论】:

可能的解决方案:discuss.pytorch.org/t/… 【参考方案1】:

您需要使用.targets 来访问数据的标签,即

print(dict(Counter(dataset.targets)))

它会打印这样的东西(例如在 MNIST 数据集中):

5: 5421, 0: 5923, 4: 5842, 1: 6742, 9: 5949, 2: 5958, 3: 6131, 6: 5918, 7: 6265, 8: 5851

另外,您可以使用.classes.class_to_idx 获取标签ID 到类的映射:

print(dataset.class_to_idx)
'0 - zero': 0,
 '1 - one': 1,
 '2 - two': 2,
 '3 - three': 3,
 '4 - four': 4,
 '5 - five': 5,
 '6 - six': 6,
 '7 - seven': 7,
 '8 - eight': 8,
 '9 - nine': 9

编辑:方法 1

从 cmets 中,为了分别获得训练集和测试集的类分布,您可以简单地迭代子集,如下所示:

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# labels in training set
train_classes = [label for _, label in train_dataset]
Counter(train_classes)
Counter(0: 4757,
         1: 5363,
         2: 4782,
         3: 4874,
         4: 4678,
         5: 4321,
         6: 4747,
         7: 5024,
         8: 4684,
         9: 4770)

编辑(2):方法2

由于您有一个大型数据集,并且正如您所说,迭代所有训练集需要相当长的时间,还有另一种方法:

您可以使用子集的.indices,它指的是为子集选择的原始数据集中的索引。

train_classes = [dataset.targets[i] for i in train_dataset.indices]
Counter(train_classes) # if doesn' work: Counter(i.item() for i in train_classes)

【讨论】:

您必须将原始数据集拆分为训练集和测试集,然后您才能访问它(我认为您不能从数据加载器访问它)例如***.com/a/51768651/6210807 哦,不过,您可以简单地遍历子集然后获取类,检查我的编辑。 我会检查是否有其他方法,同时,您可以这样做:因为您可以使用 .targets 在完整数据集上分配类,所以在测试数据集上运行上述循环(它'将花费更少的时间),并从总数据集中减去类,这样您将获得训练集上的类分布。 是的,这是有道理的,我很高兴从您那里听到更好的方法。 @AminBashiri 尝试方法 2(使用 .indices)。检查编辑。【参考方案2】:

简单易行 如果你有dataset 类,在你的情况下是ImageFolder

dataset = MyDataset() # which in your case in ImageFolder
labels = torch.zeros(num_classes, dtype=torch.long)

for _, target in dataset:
    labels += target

【讨论】:

以上是关于pytorch 数据集中每个类的实例数的主要内容,如果未能解决你的问题,请参考以下文章

如何使用 Pytorch 将增强图像添加到原始数据集中?

Pytorch 中的标注:多目标数据集的不一致增强

pytorch生成对抗网络GAN的基础教学简单实例(附代码数据集)

PyTorch版Mask R-CNN图像实例分割实战:训练自己的数据集

PyTorch Big Graph 嵌入数据集中优化器 state_dict 的目的是啥?

pytorch之wod2vec实现