如何从 pytorch 数据加载器中获取批迭代的总数?
Posted
技术标签:
【中文标题】如何从 pytorch 数据加载器中获取批迭代的总数?【英文标题】:How to get the total number of batch iteration from pytorch dataloader? 【发布时间】:2021-01-03 21:49:35 【问题描述】:我有一个问题,如何从 pytorch 数据加载器获取批迭代的总数?
以下是训练的常用代码
for i, batch in enumerate(dataloader):
那么,有什么方法可以得到“for循环”的总迭代次数吗?
在我的NLP问题中,总的迭代次数与int(n_train_samples/batch_size)不同...
例如,如果我只截断训练数据 10,000 个样本并将批量大小设置为 1024,那么在我的 NLP 问题中会发生 363 次迭代。
我想知道如何获得“for-loop”中的总迭代次数。
谢谢。
【问题讨论】:
【参考方案1】:创建数据加载器时还有一个附加参数。它被称为drop_last
。
如果drop_last=True
,则长度为number_of_training_examples // batch_size
。
如果drop_last=False
可能是number_of_training_examples // batch_size +1
。
BS=128
ds_train = torchvision.datasets.CIFAR10('/data/cifar10', download=True, train=True, transform=t_train)
dl_train = DataLoader( ds_train, batch_size=BS, drop_last=True, shuffle=True)
对于预定义的数据集,您可能会得到如下示例的数量:
# number of examples
len(dl_train.dataset)
dataloader 中的正确批次数始终为:
# number of batches
len(dl_train)
【讨论】:
【参考方案2】:len(dataloader)
返回批次总数。这取决于您的数据集的__len__
函数,因此请确保设置正确。
【讨论】:
@HyunseungKim 你的数据集中有__len__
函数吗?数据加载器的 len 函数取决于数据集的 __len__
函数。我每天都使用它,所以我确信它有效:)
很抱歉,我认为有一些错误。实际上,我出于某种目的尝试修改此代码(github.com/SamLynnEvans/Transformer)。在脚本“Batch.py”中,有一个名为“MyIterator”的类,它返回train_iter。但是,我不能确定它是否是数据加载器……我必须进一步检查。
现在我明白你所说的“这取决于数据集的 __ len __ 函数,因此请确保设置正确”。 __ len__ 必须写。以上是关于如何从 pytorch 数据加载器中获取批迭代的总数?的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch中如何使用DataLoader对数据集进行批训练