pytorch-lightning train_dataloader 用完数据

Posted

技术标签:

【中文标题】pytorch-lightning train_dataloader 用完数据【英文标题】:pythorch-lightning train_dataloader runs out of data 【发布时间】:2020-09-12 08:39:51 【问题描述】:

我开始使用 pytorch-lightning,但遇到了自定义数据加载器的问题:

我使用自己的数据集和通用的 torch.utils.data.DataLoader。基本上,数据集采用路径并加载与数据加载器加载的给定索引对应的数据。

def train_dataloader(self):
    train_set = TextKeypointsDataset(parameters...)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size, num_workers)
    return train_loader 

当我使用 pytorch-lightning 模块 train_dataloadertraining_step 时,一切运行良好。当我添加val_dataloadervalidation_step 时,我遇到了这个错误:

Epoch 1:  45%|████▌     | 10/22 [00:02<00:03,  3.34it/s, loss=5.010, v_num=131199]
ValueError: Expected input batch_size (1500) to match target batch_size (5)

在这种情况下,我的数据集非常小(用于测试功能),只有 84 个样本,我的批量大小为 8。用于训练和验证的数据集长度相同(仅用于测试目的)。

所以总共有 84 * 2 = 168 和 168 / 8 (batchsize) = 21,大致就是上面显示的总步数 (22)。这意味着在训练数据集上运行 10 次 (10 * 8 = 80) 后,加载器期望新的完整样本为 8,但由于只有 84 个样本,我得到一个错误(至少这是我目前的理解)。

我在自己的实现中遇到了类似的问题(不使用 pytorch-lighntning)并使用此模式来解决它。基本上,当数据用完时,我正在重置迭代器:

try:
    data = next(data_iterator)
    source_tensor = data[0]
    target_tensor = data[1]

except StopIteration:  # reinitialize data loader if num_iteration > amount of data
    data_iterator = iter(data_loader)

现在好像我面临着类似的事情?当我的 training_dataloader 数据不足时,我不知道如何在 pytorch-lightning 中重置/重新初始化数据加载器。我想一定有另一种我不熟悉的复杂方式。谢谢

【问题讨论】:

实现自己的Dataset 是相当标准的,但定义自定义DataLoader 可能是一个错误,因为它在后端执行各种复杂的事情(多线程等)。在最极端的情况下,您应该能够定义自己的 Sampler 和可能的 collate_fn(如有必要),这两者都将在构建时提供给您的 DataLoader 我编辑了我的问题以使其更清楚。我使用自己的数据集,但不是自定义数据加载器 【参考方案1】:

解决办法是:

我使用source_tensor = source_tensor.view(-1, self.batch_size, self.input_size),后来导致一些错误,现在我使用source_tensor = source_tensor.permute(1, 0, 2),解决了问题。

【讨论】:

以上是关于pytorch-lightning train_dataloader 用完数据的主要内容,如果未能解决你的问题,请参考以下文章

如何禁用 PyTorch-Lightning 记录器的日志记录?

无法从 Pytorch-Lightning 中的检查点加载模型

pytorch-lightning入门—— 初了解

PyTorch-lightning 模型在第一个 epoch 后内存不足

使用 pytorch-lightning 进行简单预测的示例

使用 pytorch-lightning 实现 Network in Network CNN 模型