pytorch中的数据导入之DataLoader和Dataset的使用介绍
Posted 非晚非晚
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch中的数据导入之DataLoader和Dataset的使用介绍相关的知识,希望对你有一定的参考价值。
文章目录
在使用Pytorch构建和训练模型的过程中,经常需要把原始数据(图片、文本等)转换为张量的格式。对于小数据集,我们可以手动导入,但是在深度学习中,数据集往往是比较大的,这时pytorch的数据导入功能便发挥了作用,Pytorch导入数据主要依靠
torch.utils.data.DataLoader
和
torch.utils.data.Dataset
这两个类来完成。
torch.utils.data.Dataset
:这是一个抽象类
,所以我们需要对其进行派生,从而使用其派生类来创建数据集
。最主要的两个函数实现为__Len__
和__getitem
。
__init__
:可以在这里设置加载的data和label。__Len__
:获取数据集大小__getitem
:根据索引获取一条训练的数据和标签。
torch.utils.data.DataLoader
:接收torch.utils.data.Dataset作为输入,得到DataLoader,它是一个迭代器
,方便我们去多线程地读取数据,并且可以实现batch以及shuffle的读取等。
pytorch 的数据加载到模型的操作顺序如下:
(看完本文章,再回过头看这部分会更清晰
):
- 创建一个 Dataset 对象
- 创建一个 DataLoader 对象
- 循环这个 DataLoader 对象,将img, label加载到模型中进行训练
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for img, label in dataloader:
....
1. 构建数据集-torch.utils.data.Dataset
torch.utils.data.Dataset可通过两种方式生成,一种是通过内置的下载功能,另一种便是自己实现
。下载功能需要借助一些其他的包,比如torchvision,下载CIFAR10数据格式大致如下:
import torch
import torchvision
cf10_data = torchvision.datasets.CIFAR10('dataset/cifar/', download=True)
下面我们主要介绍怎么创建一个属于自己的数据集。
任何自定义的数据集都要继承自torch.utils.data.Dataset,然后重写两个函数:__len__(self)
和__getitem__(self, idx)
。
下面是一个简单的自定义小型数据集,以期能够理解它的创建方式。注意下面的示例中并没有通过__init__
传入data和label,而是在内部创建的。
import torch
from torch.utils.data import Dataset
class myDataset(Dataset):
def __init__(self):
#创建5*2的数据集
self.data = torch.tensor([[1,2],[3,4],[2,1],[3,4],[4,5]])
#5个数据的标签
self.label = torch.tensor([0,1,0,1,2])
#根据索引获取data和label
def __getitem__(self,index):
return self.data[index], self.label[index] #以元组的形式返回
#获取数据集的大小
def __len__(self):
return len(self.data)
data = myDataset()
print(f'data size is : len(data)')
print(data[1]) #获取索引为1的data和label
输出:
data size is : 5
(tensor([3, 4]), tensor(1))
2. 数据载入-torch.utils.data.DataLoader
torch.utils.data.Dataset
通过__getitem__
获取单个数据,如果希望获取批量数据、shuffle或者其它的一些操作,那么就要由torch.utils.data.DataLoader来实现了,它的实现形式如下:
data.DataLoader(
dataset,
batch_size = 50,
shuffle = False,
sampler=None,
batch_sampler = None,
num_workers = 0,
collate_fn =
pin_memory = False,
drop_last = False,
timeout = 0,
worker_init_fn = None,
)
dataset
:待传入的数据集,也就是上面自己实现的myData。batch_size
:每个batch有多少个样本shuffle
:代表数据会不会被随机打乱,在每个epoch开始的时候,对数据进行重新排序
。sampler
:自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为Falsebatch_sampler
:类似于sampler,不过返回的是一个迷你批次的数据索引。num_workers
:是数据载入器使用的进程数目,默认为0。collate_fn
:用于自定义sample 如何形成 batch sample 的函数。pin_memory
:如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中。drop_last
:如果设置为true,那么最后的batch的大小如果小于batch_size,那么则会丢弃。timeout
:如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0。worker_init_fn
:它决定了每个数据载入的子进程开始时运行的函数
按照上面的Dataset,使用DataLoader加载数据示例如下,因为设置了drop_last = True
,所以最后一个batch会被丢弃。
from torch.utils.data import DataLoader
data = myDataset()
my_loader = DataLoader(data,batch_size=2,shuffle=False,num_workers = 0,drop_last=True)
for step,train_data in enumerate(my_loader):
Data,Label = train_data
print("step:",step)
print("data:",Data)
print("Label:",Label)
输出:
step: 0
data: tensor([[1, 2],
[3, 4]])
Label: tensor([0, 1])
step: 1
data: tensor([[2, 1],
[3, 4]])
Label: tensor([0, 1])
3. 把数据放入GPU中
在Dataset和DataLoader的地方都可以实现把数据放入GPU,下面分别进行介绍。
- Dataset阶段把数据放入GPU
如果在此阶段把数据放入GPU,则此阶段必须把num_workers设置为0,要不然会报错。此阶段的操作需要在__getitem__
中实现,实现过程大致如下。
def __getitem__(self, index):
data = torch.Tensor(self.Data[index])
label = torch.IntTensor(self.Label[index])
if torch.cuda.is_available():
data = data.cuda()
label = label.cuda()
return data, label
- DataLoader阶段把数据放入GPU
这种实现方式就没有特别需要注意的地方,直接把tensor放入GPU即可,所以推荐使用这种实现方式
,如下所示。
data = myDataset()
my_loader = DataLoader(data,2,shuffle=False,num_workers = 0,drop_last=True)
for step,train_data in enumerate(my_loader):
Data,Label = train_data
#把数据放在GPU中
if torch.cuda.is_available():
data = data.cuda()
label = label.cuda()
print("step:",step)
print("data:",Data)
print("Label:",Label)
以上是关于pytorch中的数据导入之DataLoader和Dataset的使用介绍的主要内容,如果未能解决你的问题,请参考以下文章
深度之眼PyTorch训练营第二期 ---5Dataloader与Dataset
PyTorch源码解读之torch.utils.data.DataLoader(转)
PyTorch之torch.utils.data.DataLoader解读