具有多个工作人员的可迭代 pytorch 数据集

Posted

技术标签:

【中文标题】具有多个工作人员的可迭代 pytorch 数据集【英文标题】:Iterable pytorch dataset with multiple workers 【发布时间】:2021-12-15 02:07:03 【问题描述】:

所以我有一个比我的 ram 内存大的文本文件,我想在 PyTorch 中创建一个数据集,它可以逐行读取,所以我不必一次将它全部加载到内存中。我发现 pytorch IterableDataset 作为我的问题的潜在解决方案。它仅在使用 1 个工作人员时按预期工作,如果使用多个工作人员,它将创建重复记录。让我给你举个例子:

有一个testfile.txt 包含:

0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line

定义一个 IterableDataset:

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):
        
        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        return text, label


    def __iter__(self):

        #Create an iterator
        file_itr = open(self.filename)

        #Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        
        return mapped_itr

我们现在可以测试它了:

base_dataset = CustomIterableDatasetv1("testfile.txt")
#Wrap it around a dataloader
dataloader = DataLoader(base_dataset, batch_size = 1, num_workers = 1)
for X, y in dataloader:
    print(X,y)

它输出:



('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)

没错。但是,如果我将工人数量更改为 2,则输出变为

('0',) (' Dummy line\n',)
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
('9',) (' Dummy line',)

这是不正确的,因为在数据加载器中为每个工作人员创建每个样本的副本。

有没有办法用 pytorch 解决这个问题?因此,可以创建一个数据加载器来不加载内存中的所有文件,并支持多个工作人员。

【问题讨论】:

【参考方案1】:

您可以使用torch.utils.data.get_worker_info 实用程序访问Dataset__iter__ 函数中的worker 标识符。这意味着您可以单步执行迭代器并根据工作人员 id 添加偏移量。您可以使用 itertools.islice 包装一个迭代器,这允许您步进 start 索引以及 step

这是一个最小的例子:

class DS(IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        uid = torch.utils.data.get_worker_info().id
        itr = islice(range(10), uid, None, self.batch_size)

即使我们使用num_workers > 1,循环通过数据加载器也会产生唯一的实例:

>>> for x in DataLoader(DS(batch_size=2), batch_size=2, num_workers=2):
...     print(x)
tensor([0, 2])
tensor([1, 3])
tensor([4, 6])
tensor([5, 7])
tensor([8])
tensor([9])

你可以这样做:

    def __iter__(self):
        # create an iterator
        file_itr = open(self.filename)

        # map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
    
        # wrap the iterator
        step_itr = islice(mapped_itr, uid, None, self.batch_size)

        return step_itr

【讨论】:

如何将此逻辑添加到文件读取中?我了解您如何使用已经定义的列表来执行此操作,但是如何在旅途中阅读文件呢? 我真的不明白,您的可迭代数据集不是已经实现并使用num_workers=1了吗? 是的,但是您在创建迭代器时基于uid 进行索引,而不是在创建迭代器之后。这就是为什么我不遵循如何将它放在我当前的 IterableDataset 之上的原因 __iter__ 函数由数据加载器调用,对于每个工人数据加载器第一次循环。您是否尝试在 __iter__ 函数中返回 iter(list(mapped_itr)[uid::self.batch_size]) 还需要一个更改,而不是切片时 self.batch_size 应该是工人数。【参考方案2】:

所以我在火炬讨论论坛https://discuss.pytorch.org/t/iterable-pytorch-dataset-with-multiple-workers/135475/3 中找到了答案,他们指出我应该使用工人信息连续切片到批量大小。

新数据集如下所示:

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):
        
        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        return text, label


    def __iter__(self):
        worker_total_num = torch.utils.data.get_worker_info().num_workers
        worker_id = torch.utils.data.get_worker_info().id
        #Create an iterator
        file_itr = open(self.filename)

        #Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        
        #Add multiworker functionality
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)

        return mapped_itr

特别感谢@Ivan,他也指出了切片解决方案。

如果有两个工人,它只返回与 1 个工人相同的数据

【讨论】:

最好进行编辑...无论如何,请考虑支持原始答案。

以上是关于具有多个工作人员的可迭代 pytorch 数据集的主要内容,如果未能解决你的问题,请参考以下文章

具有多个值的张量的布尔值在 Pytorch 中不明确

TFS 中具有多个版本的迭代

在迭代期间可以变异的可迭代集合

具有共享内存的 Pytorch 多处理导致 matmul 慢 30 倍(只有两个进程)

从甚至可迭代的[重复]中获取2元组的可迭代

如何在 PyTorch Lightning 中编写多个训练设置