Pytorch加载数据集的方式总结
Posted 咕噜咕噜冰阔落
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch加载数据集的方式总结相关的知识,希望对你有一定的参考价值。
Pytorch加载数据集的方式总结
- 一、自己重写定义(Dataset、DataLoader)
- 二、用Pytorch自带的类(ImageFolder、datasets、DataLoader)
- 三、总结
- 四、transforms变换讲解
- 五、DataLoader的补充
在用Pytorch加载数据集时,看GitHub上的代码经常会用到ImageFolder、DataLoader等一系列方法,而这些方法又是来自于torchvision、torch.utils.data。除加载数据集外,还有torchvision中的transforms对数据集预处理…等等等等。这个data,那个dataset…这一系列下来,不加注意的话实在有点打脑壳。看别人的代码加载数据集挺简单,但是自己用的时候,尤其是加载自己所制作的数据集的时候,就会茫然无措。别无他法,抱着硬啃的心态,查阅了其他博文,通过代码实验,终于是理清楚了思路。
Pytorch加载数据集可以分两种大的情况:一、自己重写定义; 二、用Pytorch自带的类。第二种里面又有多种不同的方法(datasets、 ImageFolder等),但这些方法都有相同的处理规律。我理解的,无论是哪种情况,加载数据集都需要构造数据加载器和数据装载器(后者生成的是可迭代的数据)。现将这两种情况一一说明。
一、自己重写定义(Dataset、DataLoader)
目前我们有自己制作的数据以及数据标签,但是有时候感觉不太适合直接用Pytorch自带加载数据集的方法。我们可以自己来重写定义一个类,这个类继承于 torch.utils.data.Dataset,同时我们需要重写这个类里面的两个方法 _ getitem__ () 和__ len()__函数。
如下所示。这两种方法如何构造以及具体的细节可以查看其他的博客。len方法必须返回数据的长度,getitem方法必须返回数据以及标签。
import torch
import numpy as np
# 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法
class GetLoader(torch.utils.data.Dataset):
# 初始化函数,得到数据
def __init__(self, data_root, data_label):
self.data = data_root
self.label = data_label
# index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
def __getitem__(self, index):
data = self.data[index]
labels = self.label[index]
return data, labels
# 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
def __len__(self):
return len(self.data)
# 随机生成数据,大小为10 * 20列
source_data = np.random.rand(10, 20)
# 随机生成标签,大小为10 * 1列
source_label = np.random.randint(0,2,(10, 1))
# 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels
torch_data = GetLoader(source_data, source_label)
通过上述的程序,我们构造了一个数据加载器torch_data,但是还是不能直接传入网络中。接下来需要构造数据装载器,产生可迭代的数据,再传入网络中。DataLoader类完成这个工作。
torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
参数解释:
1.dataset : 加载torch.utils.data.Dataset对象数据
2.batch_size : 每个batch的大小,将我们的数据分批输入到网络中
3.shuffle : 是否对数据进行打乱
4.drop_last : 是否对无法整除的最后一个datasize进行丢弃
5.num_workers : 表示加载的时候子进程数
结合我们自己定义的加载数据集类,可以如下使用。后面将data和label传入我们定义的模型中。
...
torch_data = GetLoader(source_data, source_label)
from torch.utils.data import DataLoader
datas = DataLoader(torch_data, batch_size = 4, shuffle = True, drop_last = False, num_workers = 2)
for i, (data, label) in enumerate(datas):
# i表示第几个batch, data表示batch_size个原始的数据,label代表batch_size个数据的标签
print("第 个Batch \\n".format(i, data))
二、用Pytorch自带的类(ImageFolder、datasets、DataLoader)
2.1 加载自己的数据集
2.1.1 ImageFolder介绍
和第一种情况不一样,我们不需要在代码上自己定义数据集类了,而是将数据集按照一定的格式摆放,调用ImageFolder类即可。这种是在调用Pytorch内部的API,所以我们自己的数据集得需要按照API内部所规定的存放格式。torchvision.datasets.ImageFolder 要求数据集按照如下方式组织。根目录 root 下存储的是类别文件夹(如cat,dog),每个类别文件夹下存储相应类别的图像(如xxx.png)
A generic data loader where the images are arranged in this way:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
torchvision.datasets.ImageFolder有以下参数:
dataset=torchvision.datasets.ImageFolder(
root, transform=None,
target_transform=None,
loader=<function default_loader>,
is_valid_file=None)
参数解释:
1.root:根目录,在root目录下,应该有不同类别的子文件夹;
|--data(root)
|--train
|--cat
|--dog
|--valid
|--cat
|--dog
2.transform:对图片进行预处理的操作,原始图像作为一个输入,返回的是transform变换后的图片;
3.target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。 如果不传该参数,即对target不做任何转换,返回的顺序索引 0,1, 2…
4.loader:表示数据集加载方式,通常默认加载方式即可;
5.is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)
作为torchvision.datasets.ImageFolder的返回,会有以下三种属性:
(1)self.classes:用一个 list 保存类别名称
(2)self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
(3)self.imgs:保存(img_path, class) tuple的list
以猫狗类别举例,各属性输出如下所示:
print(dataset.classes) #根据分的文件夹的名字来确定的类别
print(dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
print(dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
'''
输出:
['cat', 'dog']
'cat': 0, 'dog': 1
[('./data/train\\\\cat\\\\1.jpg', 0),
('./data/train\\\\cat\\\\2.jpg', 0),
('./data/train\\\\dog\\\\1.jpg', 1),
('./data/train\\\\dog\\\\2.jpg', 1)]
'''
2.2.2 ImageFolder加载数据集完整例子
# 1.导入相关数据库
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
# 2.定义图片转换方式
train_transforms = torchvision.transforms.Compose([
transforms.RandomResizedCrop(400),
transforms.ToTensor()
])
# 3. 定义地址
path = os.path.join(os.getcwd(), 'data', 'train')
# 4. 将文件夹数据导入
dataset = ImageFolder(root=path, transform=train_transforms)
和第一种情况自己重写定义一样,上述的代码仅仅完成了数据加载器的定义。这样是不能直接传入网络中进行训练的,需要再构造一个可迭代的数据装载器。DataLoader类的使用方式上文中有详细介绍。
# 5. 将文件夹数据导入
train_loader = torch.utils.data.DataLoader(dataset,
batch_size = batch_size, shuffle=True,
num_workers = 2)
# 6. 传入网络进行训练
for epoch in range(epochs):
train_bar = tqdm(train_loader, file = sys.stdout)
for step, data in enumerate(train_bar):
...
2.2 加载常见的数据集
有些数据集是公共的,比如常见的MNIST,CIFAR10,SVHN等等。这些数据集在Pytorch中可以通过代码就可以下载、加载。如下代码所示。用torchvision中的datasets类下载数据集,并还是结合DataLoader来构建可直接传入网络的数据装载器。
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
def dataloader(dataset, input_size, batch_size, split='train'):
transform = transforms.Compose([
transforms.Resize((input_size, input_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
if dataset == 'mnist':
data_loader = DataLoader(
datasets.MNIST('data/mnist', train=True, download=True, transform=transform),
batch_size=batch_size, shuffle=True)
elif dataset == 'fashion-mnist':
data_loader = DataLoader(
datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transform),
batch_size=batch_size, shuffle=True)
elif dataset == 'cifar10':
data_loader = DataLoader(
datasets.CIFAR10('data/cifar10', train=True, download=True, transform=transform),
batch_size=batch_size, shuffle=True)
elif dataset == 'svhn':
data_loader = DataLoader(
datasets.SVHN('data/svhn', split=split, download=True, transform=transform),
batch_size=batch_size, shuffle=True)
elif dataset == 'stl10':
data_loader = DataLoader(
datasets.STL10('data/stl10', split=split, download=True, transform=transform),
batch_size=batch_size, shuffle=True)
elif dataset == 'lsun-bed':
data_loader = DataLoader(
datasets.LSUN('data/lsun', classes=['bedroom_train'], transform=transform),
batch_size=batch_size, shuffle=True)
return data_loader
三、总结
至于觉得加载数据集比较难的很大的原因,我感觉是Dataset、datasets、DataLoader以及torch.utils.data、torchvision种类太多,有点混乱。上面的梳理,我的理解是无论是哪种方式,终端还是需要DataLoader整合。作为加载数据集的前端,用自己定义的、用ImageFolder的、还是用datasets加载常用数据集,都是在构造数据加载器,而且构造起来也并不复杂。梳理清晰后,相信对Pytorch加载数据集有了更进一步的理解。
四、transforms变换讲解
torchvision.transforms是Pytorch中的图像预处理包。一般定义在加载数据集之前,用transforms中的Compose类把多个步骤整合到一起,而这些步骤是transforms中的函数。
transforms中的函数有这些:
函数 | 含义 |
---|---|
transforms.Resize | 把给定的图片resize到given size |
transforms.Normalize | 用均值和标准差归一化张量图像 |
transforms.Totensor | 可以将PIL和numpy格式的数据从[0,255]范围转换到[0,1] ; <br /另外原始数据的shape是(H x W x C),通过transforms.ToTensor() 后shape会变为(C x H x W) |
transforms.RandomGrayscale | 将图像以一定的概率转换为灰度图像 |
transforms.ColorJitter | 随机改变图像的亮度对比度和饱和度 |
transforms.Centercrop | 在图片的中间区域进行裁剪 |
transforms.RandomCrop | 在一个随机的位置进行裁剪 |
transforms.FiceCrop | 把图像裁剪为四个角和一个中心 |
transforms.RandomResizedCrop | 将PIL图像裁剪成任意大小和纵横比 |
transforms.ToPILImage | convert a tensor to PIL image |
transforms.RandomHorizontalFlip | 以0.5的概率水平翻转给定的PIL图像 |
transforms.RandomVerticalFlip | 以0.5的概率竖直翻转给定的PIL图像 |
transforms.Grayscale | 将图像转换为灰度图像 |
不同函数对应有不同的属性,用transforms.Compose将不同的操作整合在一起,如下所示。
transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
五、DataLoader的补充
数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。
在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。
用下面的例子测试:
"""
批训练,把数据变成一小批一小批数据进行训练。
DataLoader就是用来包装所使用的数据,每次抛出一批数据
"""
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
# 把数据放在数据库中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
# 从数据库中每次抽出batch size个样本
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=2,
)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training
print("steop:, batch_x:, batch_y:".format(step, batch_x, batch_y))
if __name__ == '__main__':
show_batch()
结果如下所示。仔细观察:
每一个step,batch_x是不会重合的,batch_y里面的值也是不会重合的(第一个step中,batch_x:tensor([ 3., 10., 6., 2., 8.]);第二个step中batch_x:tensor([5., 9., 7., 4., 1.])),说明DataLoader将数据打乱后,每次选用其中的Batch_size个数据且不会重复;
其二,batch_x 和 batch_y对应的索引之和相等,这说明DataLoader对图像和标签打乱顺序时,同时按照某一规律打乱,并不会造成标签和图像出现不对应的情况。
其三,在不同的epoch之间,每次数据也是不同的,说明DataLoader每次被调用时,都会重新打乱一次。
steop:0, batch_x:tensor([ 3., 10., 6., 2., 8.]), batch_y:tensor([8., 1., 5., 9., 3.])
steop:1, batch_x:tensor([5., 9., 7., 4., 1.]), batch_y:tensor([ 6., 2., 4., 7., 10.])
steop:0, batch_x:tensor([8., 3., 1., 2., 9.]), batch_y:tensor([ 3., 8., 10., 9., 2.])
steop:1, batch_x:tensor([10., 5., 4., 7., 6.]), batch_y:tensor([1., 6., 7., 4., 5.])
steop:0, batch_x:tensor([5., 8., 4., 3., 7.]), batch_y:tensor([6., 3., 7., 8., 4.])
steop:1, batch_x:tensor([ 2., 10., 6., 9., 1.]), batch_y:tensor([ 9., 1., 5., 2., 10.])
PyTorch 数据加载器显示字符串数据集的奇怪行为
【中文标题】PyTorch 数据加载器显示字符串数据集的奇怪行为【英文标题】:PyTorch dataloader shows odd behavior with string dataset 【发布时间】:2021-03-01 04:26:38 【问题描述】:我正在处理一个 NLP 问题并且正在使用 PyTorch。 由于某种原因,我的数据加载器返回了格式错误的批次。我输入了包含句子和整数标签的数据。 句子可以是句子列表或标记列表列表。稍后我将在下游组件中将标记转换为整数。
list_labels = [ 0, 1, 0]
# List of sentences.
list_sentences = [ 'the movie is terrible',
'The Film was great.',
'It was just awful.']
# Or list of list of tokens.
list_sentences = [['the', 'movie', 'is', 'terrible'],
['The', 'Film', 'was', 'great.'],
['It', 'was', 'just', 'awful.']]
我创建了以下自定义数据集:
import torch
from torch.utils.data import DataLoader, Dataset
class MyDataset(torch.utils.data.Dataset):
def __init__(self, sentences, labels):
self.sentences = sentences
self.labels = labels
def __getitem__(self, i):
result =
result['sentences'] = self.sentences[i]
result['label'] = self.labels[i]
return result
def __len__(self):
return len(self.labels)
当我以句子列表的形式提供输入时,数据加载器正确返回成批的完整句子。注意batch_size=2
:
list_sentences = [ 'the movie is terrible', 'The Film was great.', 'It was just awful.']
list_labels = [ 0, 1, 0]
dataset = MyDataset(list_sentences, list_labels)
dataloader = DataLoader(dataset, batch_size=2)
batch = next(iter(dataloader))
print(batch)
# 'sentences': ['the movie is terrible', 'The Film was great.'], <-- Great! 2 sentences in batch!
# 'label': tensor([0, 1])
批次正确包含两个句子和两个标签,因为batch_size=2
。
但是,当我将句子输入为标记列表的预标记列表时,我得到了奇怪的结果:
list_sentences = [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.'], ['It', 'was', 'just', 'awful.']]
list_labels = [ 0, 1, 0]
dataset = MyDataset(list_sentences, list_labels)
dataloader = DataLoader(dataset, batch_size=2)
batch = next(iter(dataloader))
print(batch)
# 'sentences': [('the', 'The'), ('movie', 'Film'), ('is', 'was'), ('terrible', 'great.')], <-- WHAT?
# 'label': tensor([0, 1])
请注意,这批的 sentences
是一个包含单词对元组的单个列表。 我原以为sentences
是两个列表的列表,如下所示:
'sentences': [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.']
发生了什么事?
【问题讨论】:
我也遇到了这个问题。这似乎是一个真正的问题——pytorch 应该能够整理成批的字符串。我可以看到很多情况下您可能希望在数据加载器步骤之后处理字符串。 【参考方案1】:另一种解决方案是将字符串编码为字节并在您的Dataset
中,然后在您的前向传递中对其进行解码。如果您想包含元数据字符串(例如数据来自的文件路径),但实际上不需要将数据传递到模型中,这很有用。
例如:
class MyDataset(torch.utils.data.Dataset):
def __next__(self):
return np.array("this is a sentence").bytes()
然后在你的前向传球中你会这样做:
sentences: List[str] = []
for sentence in batch:
sentences.append(sentence.decode("ascii"))
【讨论】:
【参考方案2】:这种行为是因为默认collate_fn
在必须整理list
s 时执行following(['sentences']
就是这种情况):
# [...]
elif isinstance(elem, container_abcs.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError('each element in list of batch should be of equal size')
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
之所以会出现“问题”,是因为在最后两行中,它将递归调用 zip(*batch)
,而批处理是 container_abcs.Sequence
(和 list
是),而 zip
的行为是这样的。
如你所见:
batch = [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.']]
list(zip(*batch))
# [('the', 'The'), ('movie', 'Film'), ('is', 'was'), ('terrible', 'great.')]
除了实现一个新的整理器并将其传递给DataLoader(..., collate_fn=mycollator)
之外,我没有在您的情况下看到解决方法。例如,一个简单的丑可能是:
def mycollator(batch):
assert all('sentences' in x for x in batch)
assert all('label' in x for x in batch)
return
'sentences': [x['sentences'] for x in batch],
'label': torch.tensor([x['label'] for x in batch])
【讨论】:
谢谢。我应该像你一样深入研究批处理生成器。 我也应该认识到,当您在两个列表的同一索引处看到成对的事物时,例如 ('the', 'The')
,它可能是 zip()
的输出。以上是关于Pytorch加载数据集的方式总结的主要内容,如果未能解决你的问题,请参考以下文章