pytorch中data.DataLoader用法

Posted Shuxuan1

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch中data.DataLoader用法相关的知识,希望对你有一定的参考价值。

data.DataLoader

pytorch中data.DataLoader类实现数据的迭代。参数如下:

dataset:(数据类型 dataset)

输入的数据类型,这里是原始数据的输入。PyTorch内也有这种数据结构。

batch_size:(数据类型 int)

批训练数据量的大小,根据具体情况设置即可(默认:1)。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。每次是随机读取大小为batch_size。如果dataset中的数据个数不是batch_size的整数倍,这最后一次把剩余的数据全部输出。若想把剩下的不足batch size个的数据丢弃,则将drop_last设置为True,会将多出来不足一个batch的数据丢弃。

shuffle:(数据类型 bool)

洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。

batch_sampler:(数据类型 Sampler)

批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。我想,应该是每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。所以,它和一捆一捆按顺序输入,数据洗牌,数据采样,等模式是不兼容的。

sampler:(数据类型 Sampler)

采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。

num_workers:(数据类型 Int)

工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。注意:这个数字必须是大于等于0的,负数估计会出错。

pin_memory:(数据类型 bool)

内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。

drop_last:(数据类型 bool)

丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。

timeout:(数据类型 numeric)

超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。

使用示例

for batch_idx, (input1, input2, label1, label2) in enumerate(trainloader):
	pass

batch_idx是使用enumerate迭代时自动添加的索引,在此处代表当前迭代的是第几个batch,每次迭代的数据为batch_size个样本。

如图是打印出来每次迭代的batch_idx和每批数据的size。每次迭代的数据大小为[32,3,288,144]32为batch_size,[3,288,144]为图片数据。

CSDN 社区图书馆,开张营业! 深读计划,写书评领图书福利~

以上是关于pytorch中data.DataLoader用法的主要内容,如果未能解决你的问题,请参考以下文章

pytorch中data.DataLoader用法

PyTorch源码解读之torch.utils.data.DataLoader(转)

PyTorch之torch.utils.data.DataLoader解读

Pytorch学习笔记:数据读取机制(DataLoader与Dataset)

torch.utils.data.DataLoader()详解Pytorch入门手册

pytorch中的数据导入之DataLoader和Dataset的使用介绍