深入浅出 Dataset 与 DataLoader
Posted Mae_strive
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深入浅出 Dataset 与 DataLoader相关的知识,希望对你有一定的参考价值。
文章目录
Dataset & DataLoader
1、官方解释(Google翻译):
处理数据样本的代码可能会变得混乱且难以维护;理想情况下,我们希望我们的数据集代码与我们的模型训练代码分离,以获得更好的可读性和模块化。
PyTorch 提供了两个数据原语:torch.utils.data.DataLoader
和torch.utils.data.Dataset
允许我们使用预加载的数据集以及我们自己的数据。 Dataset存储样本及其对应的标签,并DataLoader在 周围包裹一个可迭代对象Dataset,以便轻松访问样本。
2、Dataset
是所有开发人员训练、测试使用的所有数据集的一个模板。
Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素。
DataLoader定义了按batch加载数据集的方法,它是一个实现了__iter__方法的可迭代对象,每次迭代输出一个batch的数据。
3、DataLoader
DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。
在绝大部分情况下,我们只需实现Dataset的 __len__方法 和 __getitem__方法 ,就可以轻松构建自己的数据集,并用默认数据管道进行加载。
一、自定义Dataset
自定义 Dataset 类需继承 pytorch官方的DataSet类 还必须实现三个函数:__init__、__len__和__getitem__。
init:初始化(一般需要传入 数据集文件路径,将文件保存到哪个路径,预处理函数)
len:返回数据集的大小
getitem:根据索引,返回样本的特征和标签。
import os.path
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image
class MyImageDataset(Dataset):
def __init__(self, annotations_file, data_dir, transform=None, target_transform=None):
# annotations_file:文件路径
# data_dir: 将文件保存到哪个路径
self.data_label = pd.read_csv(annotations_file)
self.data_dir = data_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
# 返回数据集总的大小
return len(self.data_label)
def __getitem__(self, item):
data_name = os.path.join(self.data_dir, self.data_label.iloc[item, 0])
image = read_image(data_name)
# 对特征进行预处理
label = self.data_label.iloc[item, 1]
if self.transform:
image = self.transform(image)
# 对标签进行预处理
if self.target_transform:
label = self.target_transform(label)
return image, label
其实我们只需要修改的是annotations_file, data_dir, transform(特征预处理), target_transform(标签预处理) 这四个参数。
Dataset每次只处理一个样本,返回的是一个特征和该特征所对应的标签
二、使用 DataLoaders 为训练准备数据
检索我们数据集的Dataset特征并一次标记一个样本。在训练模型时,我们通常希望以“小批量”的形式传递样本,在每个 epoch (每次迭代多少次)重新洗牌以减少模型过度拟合,并使用 Pythonmultiprocessing加速数据检索。
batch_size:一次训练所选取的样本数
shuffle=True: 每个训练周期后对数据进行随机排列
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
三、迭代数据
我们已将该数据集加载到 中,DataLoader并且可以根据需要遍历数据集。下面的每次迭代都会返回一批train_features和train_labels(分别包含batch_size=64特征和标签)。
iter()方法 得到一个迭代器。
next() 方法 依次获得特征和标签。
train_features, train_labels = next(iter(train_dataloader))
以上是关于深入浅出 Dataset 与 DataLoader的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch学习笔记:数据读取机制(DataLoader与Dataset)
深度之眼PyTorch训练营第二期 ---5Dataloader与Dataset
COCO_03 制作COCO格式数据集 dataset 与 dataloader
COCO_03 制作COCO格式数据集 dataset 与 dataloader