__getitem__ 的 idx 如何在 PyTorch 的 DataLoader 中工作?

Posted

技术标签:

【中文标题】__getitem__ 的 idx 如何在 PyTorch 的 DataLoader 中工作?【英文标题】:How does the __getitem__'s idx work within PyTorch's DataLoader? 【发布时间】:2020-03-09 02:05:32 【问题描述】:

我目前正在尝试使用 PyTorch 的 DataLoader 处理数据以输入我的深度学习模型,但遇到了一些困难。

我需要的数据是(minibatch_size=32, rows=100, columns=41)。我编写的自定义Dataset 类中的__getitem__ 代码如下所示:

def __getitem__(self, idx):
    x = np.array(self.train.iloc[idx:100, :])
    return x

之所以这样写,是因为我希望 DataLoader 一次处理形状为 (100, 41) 的输入实例,而我们有 32 个这样的单个实例。

但是,我注意到与我最初的看法相反,DataLoader 传递给函数的idx 参数不是连续的(这很重要,因为我的数据是时间序列数据)。例如,打印这些值给了我这样的结果:

idx = 206000
idx = 113814
idx = 80597
idx = 3836
idx = 156187
idx = 54990
idx = 8694
idx = 190555
idx = 84418
idx = 161773
idx = 177725
idx = 178351
idx = 89217
idx = 11048
idx = 135994
idx = 15067

这是正常行为吗?我发布这个问题是因为返回的数据批次不是我最初想要的。

我在使用DataLoader之前对数据进行预处理的原始逻辑是:

    txtcsv 文件中读取数据。 计算数据中有多少批次并相应地对数据进行切片。例如,由于一个输入实例的形状为 (100, 41),其中 32 个构成一个 minibatch,因此我们通常最终会得到大约 100 个左右的批次,并相应地重塑数据。 一个输入的形状为(32, 100, 41)

我不确定我应该如何处理 DataLoader 挂钩方法。非常感谢任何提示或建议。提前致谢。

【问题讨论】:

你能详细说明你的2吗? “我们通常最终得到大约 100 个”你的意思是你的数据集有 32*100 样本吗? 嗨。不,我的意思是模型的一个输入是(100, 40) 的形状,其中有 32 个形成一个小批量。 @Seankala 我试图引导您完成 DataLoader 代码。让我知道这是否有帮助。 @Berriel 是的,它帮助很大。非常感谢您花时间和精力进行详细解释! 【参考方案1】:

定义 idx 的是samplerbatch_sampler,如您所见here(开源项目是您的朋友)。在这个code(和注释/文档字符串)中,您可以看到samplerbatch_sampler 之间的区别。如果您查看here,您会看到索引是如何选择的:

def __next__(self):
    index = self._next_index()

# and _next_index is implemented on the base class (_BaseDataLoaderIter)
def _next_index(self):
    return next(self._sampler_iter)

# self._sampler_iter is defined in the __init__ like this:
self._sampler_iter = iter(self._index_sampler)

# and self._index_sampler is a property implemented like this (modified to one-liner for simplicity):
self._index_sampler = self.batch_sampler if self._auto_collation else self.sampler

注意这是_SingleProcessDataLoaderIter的实现;你可以找到_MultiProcessingDataLoaderIterhere(ofc,使用哪一个取决于num_workers的值,你可以看到here)。回到采样器,假设您的数据集不是 _DatasetKind.Iterable 并且您没有提供自定义采样器,这意味着您正在使用 (dataloader.py#L212-L215):

if shuffle:
    sampler = RandomSampler(dataset)
else:
    sampler = SequentialSampler(dataset)

if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)

我们来看看how the default BatchSampler builds a batch:

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch

非常简单:它从采样器获取索引,直到达到所需的 batch_size。

现在的问题是“__getitem__ 的 idx 如何在 PyTorch 的 DataLoader 中工作?”可以通过查看每个默认采样器的工作方式来回答。

SequentialSampler(这是完整的实现——非常简单,不是吗?):
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
RandomSampler(我们只看__iter__的实现):
def __iter__(self):
    n = len(self.data_source)
    if self.replacement:
        return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
    return iter(torch.randperm(n).tolist())

因此,由于您没有提供任何代码,我们只能假设:

    您在 DataLoader 中使用shuffle=True 您正在使用自定义采样器 您的数据集是_DatasetKind.Iterable

【讨论】:

一个绝妙的答案!

以上是关于__getitem__ 的 idx 如何在 PyTorch 的 DataLoader 中工作?的主要内容,如果未能解决你的问题,请参考以下文章

Python 类特殊方法__getitem__如何使用?

python中如何实现__getitem__ dunder方法获取属性值?

子类化 numpy ndarray 时,如何正确修改 __getitem__?

TypeError:“NoneType”对象没有属性“__getitem__”

getitem

TypeError 'x' 对象没有属性 '__getitem__'