PyTorch 数据集类 和 数据加载类 的一些尝试
Posted devilmaycry812839668
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch 数据集类 和 数据加载类 的一些尝试相关的知识,希望对你有一定的参考价值。
最近在学习PyTorch, 但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实验,小尝试。
下面给出一个常用的数据类使用方式:
def data_tf(x): x = np.array(x, dtype=‘float32‘) / 255 # 将数据变到 0 ~ 1 之间 x = (x - 0.5) / 0.5 # 标准化,这个技巧之后会讲到 x = x.reshape((-1,)) # 拉平 x = torch.from_numpy(x) return x from torchvision.datasets import MNIST # 导入 pytorch 内置的 mnist 数据 train_set = MNIST(‘./data‘, train=True, transform=data_tf, download=True) # 载入数据集,申明定义的数据变换 test_set = MNIST(‘./data‘, train=False, transform=data_tf, download=True)
其中, data_tf 并不是必须要有的,比如:
from torchvision.datasets import MNIST # 导入 pytorch 内置的 mnist 数据 train_set = MNIST(‘./data‘, train=True, download=True) # 载入数据集,申明定义的数据变换 test_set = MNIST(‘./data‘, train=False, download=True)
这里面的MNIST类是框架自带的,可以自动下载MNIST数据库, ./data 是指将下载的数据集存放在当前目录下的哪个目录下, train 这个属性 True时 则在 ./data文件夹下面在建立一个 train的文件夹然后把下载的数据存放在其中, 当train属性是False的时候则把下载的数据放在 test文件夹下面。
划线部分是老版本的PyTorch的处理方式, 最近试了一下最新版本 PyTorch 1.0 , train为True的时候是把数据放在 ./data/processed 文件夹下面, 命名为training.pt , 为False 的时候则放在 ./data/processed 文件夹下面, 命名为test.pt 。
这时候就出现了一个问题, 如果你使用的数据集不是框架自带的那么如何使用数据类呢,这个时候就要使用 pytorch 中的 Dataset 类了。
from torch.utils.data import Dataset
我们需要重写 Dataset类, 需要实现的方法为 __len__ 和 __getitem__ 这两个内置方法, 这里可以看出其思想就是要重写的类需要支持按照索引查找的方法。
这里我们还是举个例子:
从这个例子可以看出 mydataset就是我们自定义的 myDataset 类生成的自定义数据类对象。我们可以在myDataset类中自定义一些方法来对需要的数据进行处理。
为说明该问题另附加一个例子:
from torch.utils.data import Dataset #需要在pytorch中使用的数据 data=[[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3], [5.1, 5.2, 5.3]] class myDataset(Dataset): def __init__(self, indata): self.data=indata def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] mydataset=myDataset(data)
那么又来了一个问题,我们不重写 Dataset类的话可不可以呢, 经过尝试发现还真可以,如下:
又如:
由这个例子可以看出数据类对象可以不重写Dataset类, 只要具备 __len__ __getitem__ 方法就可以。而且从这个例子我们可以看出 DataLoader 是一个迭代器, 如果shuffle 设置为 True 那么在每次迭代之前都会重新排序。
同时由上面两个例子可以看出 DataLoader类会把传入的数据集合中的数据转化为 torch.tensor 类型, 当然是采用默认的 DataLoader类中转化函数 transform的情况下。
以上是关于PyTorch 数据集类 和 数据加载类 的一些尝试的主要内容,如果未能解决你的问题,请参考以下文章
pytorch中的数据加载(dataset基类,以及pytorch自带数据集)