如何在 Dataloader 中使用 Batchsampler

Posted

技术标签:

【中文标题】如何在 Dataloader 中使用 Batchsampler【英文标题】:How to use a Batchsampler within a Dataloader 【发布时间】:2020-08-10 23:22:43 【问题描述】:

我需要在 pytorch DataLoader 中使用 BatchSampler,而不是多次调用数据集的 __getitem__(远程数据集,每个查询都很昂贵)。 我不明白如何将批处理采样器与任何给定的数据集一起使用。

例如

class MyDataset(Dataset):

    def __init__(self, remote_ddf, ):
        self.ddf = remote_ddf

    def __len__(self):
        return len(self.ddf)

    def __getitem__(self, idx):
        return self.ddf[idx] --------> This is as expensive as a batch call

    def get_batch(self, batch_idx):
        return self.ddf[batch_idx]

my_loader = DataLoader(MyDataset(remote_ddf), 
           batch_sampler=BatchSampler(Sampler(), batch_size=3))

我不明白的事情是,我如何使用我的 get_batch 函数而不是 __getitem__ 函数,在网上或 Torch 文档中都没有找到任何示例。 编辑: 按照 Szymon Maszke 的回答,这是我尝试过的,但是 \_\_get_item__ 每次调用都会获得一个索引,而不是大小为 batch_size 的列表

class Dataset(Dataset):

    def __init__(self):
       ...

    def __len__(self):
        ...

    def __getitem__(self, batch_idx):  ------> here I get only one index
        return self.wiki_df.loc[batch_idx]


loader = DataLoader(
                dataset=dataset,
                batch_sampler=BatchSampler(
                    SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False),
                num_workers=self.hparams.num_data_workers,
            )

【问题讨论】:

【参考方案1】:

您不能使用get_batch 代替__getitem__,而且我认为这样做没有意义。

torch.utils.data.BatchSampler 从您的 Sampler() 实例(在本例中为 3)获取索引并将其作为 list 返回,因此可以在您的 MyDataset __getitem__ 方法中使用这些索引(检查 source code ,大多数采样器和与数据相关的实用程序都很容易使用,以备不时之需)。

我假设您的 self.ddf 支持列表切片(例如 self.ddf[[25, 44, 115]] 正确返回值并且只使用一个昂贵的调用)。在这种情况下,只需将 get_batch 切换为 __getitem__ 即可。

class MyDataset(Dataset):

    def __init__(self, remote_ddf, ):
        self.ddf = remote_ddf

    def __len__(self):
        return len(self.ddf)

    def __getitem__(self, batch_idx):
        return self.ddf[batch_idx] -> batch_idx is a list

编辑:您必须将batch_sampler 指定为sampler,否则批次将被分成单个索引。这应该没问题:

loader = DataLoader(
    dataset=dataset,
    # This line below!
    sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False
    ),
    num_workers=self.hparams.num_data_workers,
)

【讨论】:

听起来很有趣,但我无法从文档中理解它。数据集的 getitem 听起来像是返回一个样本,在我的例子中是一行。 torch.utils.data.Dataset 是一个相当灵活的结构(至少来自 pytorch 版本 1.4 IIRC)所以 index 可以是任何真正的 AFAIK。如果你使用batch_sampler,它负责创建整批数据。 当然,但是从文档的角度来看,整理功能(聚合)是为您隐式完成的,这意味着 get 得到 k 乘以 1,然后进行聚合。这意味着在 getitem 之后进行 no 聚合 collate_fn 允许您在从批处理返回数据后对其进行“后处理”。您可以从数据集中返回list[Tensor],或者在使用标准采样器时返回list[Tensor],您可以从中创建张量。很好的用例是填充可变长度张量以与 RNN 或类似方法一起使用。虽然我同意DataLoader 可能有点令人困惑。 是的哈哈哈!我现在才明白,自己来回答。谢谢!

以上是关于如何在 Dataloader 中使用 Batchsampler的主要内容,如果未能解决你的问题,请参考以下文章

如何使用 PyTorch DataLoader 进行强化学习?

Pytorch中如何使用DataLoader对数据集进行批训练

Pytorch中如何使用DataLoader对数据集进行批训练

如何使用 PyTorch 的 DataLoader 确保批次包含来自所有工作人员的样本?

如何使用 DataLoader 与 Hot Chocolate GraphQL 进行连接

火炬。在 Dataloader 中 pin_memory 是如何工作的?