Torchvision:对数据进行操作
Posted repinkply
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Torchvision:对数据进行操作相关的知识,希望对你有一定的参考价值。
Torchvision:数据读取,训练开始的第一步
如果将模型看作一辆汽车,那么它的开发过程就可以看作是一套完整的生产流程,环环相扣、缺一不可。这些环节包括数据的读取、网络的设计、优化方法与损失函数的选择以及一些辅助的工具等。未来你将尝试构建自己的豪华汽车,或者站在巨人的肩膀上对前人的作品进行优化。
试想一下,如果你对这些基础环节所使用的方法都不清楚,你还能很好地进行下去吗?所以通过这个模块,我们的目标是先把基础打好。通过这模块的学习,对于 PyTorch 都为我们提供了哪些丰富的 API,你就会了然于胸了。
Torchvision 是一个和 PyTorch 配合使用的 Python 包,包含很多图像处理的工具。我们先从数据处理入手,开始 PyTorch 的学习的第一步。我会先介绍 Torchvision 的常用数据集及其读取方法,在后面的文章里,我再带你了解常用的图像处理方法与Torchvision 其它有趣的功能。
PyTorch 中的数据读取
训练开始的第一步,首先就是数据读取。PyTorch 为我们提供了一种十分方便的数据读取机制,即使用 Dataset 类与 DataLoader 类的组合,来得到数据迭代器。在训练或预测时,数据迭代器能够输出每一批次所需的数据,并且对数据进行相应的预处理与数据增强操作。下面我们分别来看下 Dataset 类与 DataLoader 类。
Dataset 类
PyTorch 中的 Dataset 类是一个抽象类,它可以用来表示数据集。我们通过继承 Dataset类来自定义数据集的格式、大小和其它属性,后面就可以供 DataLoader 类直接使用。
其实这就表示,无论使用自定义的数据集,还是官方为我们封装好的数据集,其本质都是继承了 Dataset 类。而在继承 Dataset 类时,至少需要重写以下几个方法:
__init__():构造函数,可自定义数据读取方法以及进行数据预处理;
__len__():返回数据集大小;
__getitem__():索引数据集中的某一个数据。
光看原理不容易理解,下面我们来编写一个简单的例子,看下如何使用 Dataset 类定义一个Tensor 类型的数据集。
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 构造函数
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
# 返回数据集大小
def __len__(self):
return self.data_tensor.size(0)
# 返回索引的数据与标签
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
结合代码可以看到,我们定义了一个名字为 MyDataset 的数据集,在构造函数中,传入Tensor 类型的数据与标签;在 __len__ 函数中,直接返回 Tensor 的大小;在__getitem__ 函数中返回索引的数据与标签。
下面,我们来看一下如何调用刚才定义的数据集。首先随机生成一个 10*3 维的数据Tensor,然后生成 10 维的标签 Tensor,与数据 Tensor 相对应。利用这两个 Tensor,生成一个 MyDataset 的对象。查看数据集的大小可以直接用 len() 函数,索引调用数据可以直接使用下标。
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 构造函数
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
# 返回数据集大小
def __len__(self):
return self.data_tensor.size(0)
# 返回索引的数据与标签
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1
# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
# 查看数据集大小
print('Dataset size:', len(my_dataset))
'''
输出:
Dataset size: 10
'''
# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])
'''
输出:
tensor_data[0]:
(tensor([ 0.4931, -0.0697,
0.4171]), tensor(0))
'''
DataLoader 类
在实际项目中,如果数据量很大,考虑到内存有限、I/O 速度等问题,在训练过程中不可能一次性的将所有数据全部加载到内存中,也不能只用一个进程去加载,所以就需要多进程、迭代加载,而 DataLoader 就是基于这些需要被设计出来的。
DataLoader 是一个迭代器,最基本的使用方法就是传入一个 Dataset 对象,它会根据参数 batch_size 的值生成一个 batch 的数据,节省内存的同时,它还可以实现多进程、数据打乱等处理。
DataLoader 类的调用方式如下:
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 构造函数
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
# 返回数据集大小
def __len__(self):
return self.data_tensor.size(0)
# 返回索引的数据与标签
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1
# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
# 查看数据集大小
print('Dataset size:', len(my_dataset))
'''
输出:
Dataset size: 10
'''
# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])
'''
输出:
tensor_data[0]:
(tensor([ 0.4931, -0.0697,
0.4171]), tensor(0))
'''
from torch.utils.data import DataLoader
tensor_dataloader = DataLoader(dataset=my_dataset, # 传入的数据集, 必须参数
batch_size=2, # 输出的batch大小
shuffle=True, # 数据是否打乱
num_workers=0) # 进程数, 0表示只有主进程
# 以循环形式输出
for data, target in tensor_dataloader:
print(data, target)
# 输出一个batch
print('One batch tensor data: ', iter(tensor_dataloader).next())
结合代码,我们梳理一下 DataLoader 中的几个参数,它们分别表示:
dataset:Dataset 类型,输入的数据集,必须参数;
batch_size:int 类型,每个 batch 有多少个样本;
shuffle:bool 类型,在每个 epoch 开始的时候,是否对数据进行重新打乱;
num_workers:int 类型,加载数据的进程数,0 意味着所有的数据都会被加载进主进
程,默认为 0。
什么是 Torchvision
PyTroch 官方为我们提供了一些常用的图片数据集,如果你需要读取这些数据集,那么无需自己实现,只需要利用 Torchvision 就可以搞定。
Torchvision 是一个和 PyTorch 配合使用的 Python 包。它不只提供了一些常用数据集,还提供了几个已经搭建好的经典网络模型,以及集成了一些图像数据处理方面的工具,主要供数据预处理阶段使用。简单地说,Torchvision 库就是常用数据集 + 常见网络模型 +常用图像处理方法。
Torchvision 的安装方式同样非常简单,可以使用 conda 安装,命令如下:
conda install torchvision -c pytorch
或使用 pip 进行安装,命令如下:
pip install torchvision
Torchvision 中默认使用的图像加载器是 PIL,因此为了确保 Torchvision 正常运行,我们还需要安装一个 Python 的第三方图像处理库——Pillow 库。Pillow 提供了广泛的文件格式支持,强大的图像处理能力,主要包括图像储存、图像显示、格式转换以及基本的图像处理操作等。
使用 conda 安装 Pillow 的命令如下:
conda install pillow
使用 pip 安装 Pillow 的命令如下:
pip install pillow
安装好 Torchvision 之后,我们再来接着看看。Torchvision 库为我们读取数据提供了哪些支持。
Torchvision 库中的torchvision.datasets包中提供了丰富的图像数据集的接口。常用的图像数据集,例如 MNIST、COCO 等,这个模块都为我们做了相应的封装。
下表中列出了torchvision.datasets包所有支持的数据集。各个数据集的说明与接口,详见链接:Datasets — Torchvision 0.15 documentation
这里我想提醒你注意,torchvision.datasets这个包本身并不包含数据集的文件本身,它的工作方式是先从网络上把数据集下载到用户指定目录,然后再用它的加载器把数据集加载到内存中。最后,把这个加载后的数据集作为对象返回给用户。
Torchvision:数据增强,让数据更加多样性
上面,我们一同迈出了训练开始的第一步——数据读取,初步认识了 Torchvision,学习了如何利用 Torchvision 读取数据。不过仅仅将数据集中的图片读取出来是不够的,在训练的过程中,神经网络模型接收的数据类型是 Tensor,而不是 PIL 对象,因此我们还需要对数据进行预处理操作,比如图像格式的转换。
与此同时,加载后的图像数据可能还需要进行一系列图像变换与增强操作,例如裁切边框、调整图像比例和大小、标准化等,以便模型能够更好地学习到数据的特征。这些操作都可以使用torchvision.transforms工具完成。
今天我们就来学习一下,利用 Torchvision 如何进行数据预处理操作,如何进行图像变换与增强。
图像处理工具之 torchvision.transforms
Torchvision 库中的torchvision.transforms包中提供了常用的图像操作,包括对Tensor 及 PIL Image 对象的操作,例如随机切割、旋转、数据类型转换等等。
按照torchvision.transforms 的功能,大致分为以下几类:数据类型转换、对PIL.Image 和 Tensor 进行变化和变换的组合。下面我们依次来学习这些类别中的操作。
数据类型转换
上面,我们学习了读取数据集中的图片,读取到的数据是 PIL.Image 的对象。而在模型训练阶段,需要传入 Tensor 类型的数据,神经网络才能进行运算。
那么如何将 PIL.Image 或 Numpy.ndarray 格式的数据转化为 Tensor 格式呢?这需要用到transforms.ToTensor() 类。
torch.transforms
在torchvision.transforms中常用的数据变换操作:
torchvision.transforms.Resize:对图片数据按需求大小进行缩放,传递的参数为整型,(h,w)h表示高度,w表示高度
torchvision.transforms.Scale:对图片按需求大小进行缩放
torchvision.transforms.CenterCrop:对图片以图片中心为参考点,按需求大小进行裁剪
torchvision.transforms.RandomCrop:对图片按需求大小进行随机裁剪
torchvision.transforms.RandomHorizontalFlip:对图片按随机概率进行水平翻转,概率默认值为0.5
torchvision.transforms.RandomVerticalFlip:对图片按随机概率进行垂直翻转,概率默认值为0.5
torchvision.transforms.ToTensor:对图片进行类型转换,之前构成PIL图片的数据转换成Tensor数据类型的变量,让Pytorch能够进行计算和处理
torchvision.transforms.ToPILImage:将Tensor变量的数据转换为PIL图片数据,方便图片内容的显示
以上是关于Torchvision:对数据进行操作的主要内容,如果未能解决你的问题,请参考以下文章
(机器学习深度学习常用库框架|Pytorch篇)第三节:Pytorch之torchvision详解