如何使用 PyTorch 的 DataLoader 确保批次包含来自所有工作人员的样本?
Posted
技术标签:
【中文标题】如何使用 PyTorch 的 DataLoader 确保批次包含来自所有工作人员的样本?【英文标题】:How to ensure that a batch contains samples from all workers with PyTorch's DataLoader? 【发布时间】:2020-01-03 20:35:31 【问题描述】:我想知道如何在 PyTorch 中使用 torch.utils.data.DataLoader
,尤其是在多工人的情况下。
我发现DataLoader
的一批输出总是来自一个工人。
我希望 DataLoader 中有一个队列,它存储来自所有工作人员的数据,并且 DataLoader 将它们打乱在队列中以输出随机批处理数据。我认为这就是 Tensorflow 中tf.data.Dataset
的方式。
我们可以在 PyTorch 中实现类似的功能吗?我想通过使用多工作人员从大型序列化文件(如Tfrecord
)加载数据集。在这种情况下,在一批中混合源文件,也就是混合worker的源,就很重要了。
请参考以下代码:
import random
import time
import torch
class MyDataset(torch.utils.data.Dataset):
def __len__(self):
return 50
def __getitem__(self, idx):
info = torch.utils.data.get_worker_info()
time.sleep(random.uniform(0, 1))
print("[]:".format(info.id, idx))
return idx, info.id
if __name__ == '__main__':
dataset = MyDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False, num_workers=2)
for batch in dataloader:
print(batch)
输出:
[0]:0
[1]:5
[0]:1
[1]:6
[0]:2
[0]:3
[1]:7
[0]:4
[tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]
[1]:8
[1]:9
[tensor([5, 6, 7, 8, 9]), tensor([1, 1, 1, 1, 1])]
[0]:10
[0]:11
[1]:15
[1]:16
[0]:12
[1]:17
...
这里,[tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]
中的[0, 1, 2, 3, 4]
和[0, 0, 0, 0, 0]
表示该批次包含来自worker id 0
的索引0-th 到4-th 数据。
注意shuffle=True
并不能解决这个问题,它只会改变数据的索引。
在这种情况下,我想得到一个像:[tensor([0, 5, 1, 6, 2]), tensor([0, 1, 0, 1, 0])]
这样的批次。
【问题讨论】:
【参考方案1】:我已经实现了一些简单的方法来解决类似的问题,我将大型视频文件作为训练数据,每个工作人员负责加载和预处理单个文件,然后从中产生样本。问题在于,正如 OP 所描述的,使用 Pytorch 的默认数据加载机制,每个批次仅包含来自单个视频文件的样本。
首先,让我们回顾一下这个问题。在这个简化的代码示例中,每个工作人员都会产生一个包含其零索引工作人员 ID 的张量。批量大小为 32 和 4 个工人,我们希望每个批次包含 8 个零、8 个一、8 个二和 8 个三。
from collections import defaultdict
import torch as T
import torch.utils.data as tdata
class Dataset(tdata.IterableDataset):
def __init__(self, batch_size: int):
self._bs = batch_size
def __iter__(self):
worker_info = tdata.get_worker_info()
if not worker_info:
raise NotImplementedError('Not implemented for num_workers=0')
for _ in range(self._bs):
yield T.tensor([worker_info.id])
batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
loader = tdata.DataLoader(dataset,
batch_size=batch_size,
num_workers=num_workers)
for batch in loader:
counts = defaultdict(int)
for n in batch.numpy().flatten():
counts[n] += 1
print(dict(counts))
而是打印代码:
0: 32
1: 32
2: 32
3: 32
这意味着第一批仅包含来自工人 0 的样本,第二批仅包含来自工人 1 的样本,等等。为了解决这个问题,我们将在 DataLoader
中设置批量大小为 batch_size // num_workers
并在DataLoader
为我们的批次汇集每个工人的样本:
def pooled_batches(loader):
loader_it = iter(loader)
while True:
samples = []
for _ in range(loader.num_workers):
try:
samples.append(next(loader_it))
except StopIteration:
pass
if len(samples) == 0:
break
else:
yield T.cat(samples, dim=0)
batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
per_worker = batch_size // num_workers
loader = tdata.DataLoader(dataset,
batch_size=per_worker,
num_workers=num_workers)
for batch in pooled_batches(loader):
counts = defaultdict(int)
for n in batch.numpy().flatten():
counts[n] += 1
print(dict(counts))
代码现在打印出来了
0: 8, 1: 8, 2: 8, 3: 8
0: 8, 1: 8, 2: 8, 3: 8
0: 8, 1: 8, 2: 8, 3: 8
0: 8, 1: 8, 2: 8, 3: 8
正如预期的那样。
【讨论】:
【参考方案2】:请注意,指定了 batch_size 的 multi-worker DataLoader
将并行加载多个批次,因此基本上一个批次始终来自一个 worker。但是,通过执行以下操作,我已经实现了接近您要求的目标:
将批量大小设为 1,因此每个工人一次只能产生一个样本
编写一个遍历 DataLoader 的后台进程,一次获取 1 个样本并将其插入队列。这样就可以在队列中以不同的顺序排列样本,而不是使用特定于工人的批次
有一个批处理机制,例如 collate_fn
,它从队列中获取与您的批处理大小相等的样本并将其提供给模型
如果您想更具体地批量创建,比如从特定工作人员那里挑选特定样本,您可以有多个队列。应修改您的整理过程以考虑多个队列并从中进行选择。但我怀疑是否需要这种特殊性。
【讨论】:
感谢您的回答,解决了我的问题。我会考虑实现一种嵌套的Dataset
类,它内部有一个批量大小为 1 的 DataLoader
。以上是关于如何使用 PyTorch 的 DataLoader 确保批次包含来自所有工作人员的样本?的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch中如何使用DataLoader对数据集进行批训练
Pytorch中如何使用DataLoader对数据集进行批训练
PyTorch DataLoader 将批次作为列表返回,批次作为唯一条目。如何从我的 DataLoader 获取张量的最佳方式
__getitem__ 的 idx 如何在 PyTorch 的 DataLoader 中工作?