如何在 PyTorch 中对子集使用不同的数据增强

Posted

技术标签:

【中文标题】如何在 PyTorch 中对子集使用不同的数据增强【英文标题】:How to use different data augmentation for Subsets in PyTorch 【发布时间】:2019-01-17 19:18:35 【问题描述】:

如何在 PyTorch 中为不同的Subsets 使用不同的数据增强(转换)?

例如:

train, test = torch.utils.data.random_split(dataset, [80000, 2000])

traintest 将具有与 dataset 相同的转换。如何对这些子集使用自定义转换?

【问题讨论】:

【参考方案1】:

我目前的解决方案不是很优雅,但很有效:

from copy import copy

train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
train_dataset.dataset = copy(full_dataset)

test_dataset.dataset.transform = transforms.Compose([
    transforms.Resize(img_resolution),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset.dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(img_resolution[0]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

基本上,我正在为其中一个拆分定义一个新数据集(它是原始数据集的副本),然后为每个拆分定义一个自定义转换。

注意:train_dataset.dataset.transform 有效,因为我使用的是 ImageFolder 数据集,它使用 .tranform 属性来执行转换。

如果有人知道更好的解决方案,请与我们分享!

【讨论】:

是的,PyTorch 数据集 API 有点简陋。内置数据集没有相同的属性,一些转换仅用于 PIL 图像,一些仅用于数组,Subset 不委托给包装的数据集......我希望这会在未来改变,但现在我没有'不认为有更好的方法来做到这一点【参考方案2】:

我已经放弃并复​​制了我自己的子集(几乎与 pytorch 相同)。我将转换保留在子集中(而不是父集)。

class Subset(Dataset):
    r"""
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __getitem__(self, idx):
        im, labels = self.dataset[self.indices[idx]]
        return self.transform(im), labels

    def __len__(self):
        return len(self.indices)

您还必须编写自己的拆分函数

【讨论】:

【参考方案3】:

这是我用的(取自here):

import torch
from torch.utils.data import Dataset, TensorDataset, random_split
from torchvision import transforms

class DatasetFromSubset(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)

这是一个例子:

init_dataset = TensorDataset(
    torch.randn(100, 3, 24, 24),
    torch.randint(0, 10, (100,))
)

lengths = [int(len(init_dataset)*0.8), int(len(init_dataset)*0.2)]
train_subset, test_subset = random_split(init_dataset, lengths)

train_dataset = DatasetFromSubset(
    train_set, transform=transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5))
)
test_dataset = DatasetFromSubset(
    test_set, transform=transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5))
)

【讨论】:

以上是关于如何在 PyTorch 中对子集使用不同的数据增强的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Pytorch 中使用 torchvision.transforms 对分割任务进行数据增强?

如何在 PySpark 中的大型 Spark 数据框中对行的每个子集进行映射操作

如何使用 Pytorch 将增强图像添加到原始数据集中?

Pytorch 中的标注:多目标数据集的不一致增强

pytorch 蛋白增强 p 值?

如何在Python中对满足某些条件的行进行子集[重复]