PyTorch:数据读取1 - Datasets及数据集划分
Posted -柚子皮-
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch:数据读取1 - Datasets及数据集划分相关的知识,希望对你有一定的参考价值。
什么是Datasets?
在输入流水线中,准备数据的代码是这么写的
data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)
datasets.CIFAR10
就是一个Datasets
子类,data
是这个类的一个实例。
为什么要定义Datasets?
PyTorch
提供了一个工具函数torch.utils.data.DataLoader
。通过这个类,我们可以让数据变成mini-batch,且在准备mini-batch
的时候可以多线程并行处理,这样可以加快准备数据的速度。
Datasets
就是构建这个类的实例的参数之一。
DataLoader的使用参考[
PyTorch:数据读取2 - Dataloader]。
数据集划分
1 建议使用sklearn.preprocessing.model_selection
ds_train, ds_eval = model_selection.train_test_split(dataset, test_size=0.2, shuffle=args.if_shuffle_data)
2 train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
Note: dataloader应该是不能进行划分的。
自定义Datasets
框架
dataset
必须继承自torch.utils.data.Dataset。
内部要实现两个函数:一个是__lent__
用来获取整个数据集的大小,一个是__getitem__
用来从数据集中得到一个数据片段item
。
import torch.utils.data as data
class CustomDataset(data.Dataset): # 继承data.Dataset
"""Custom data.Dataset compatible with data.DataLoader."""
def __init__(self, filename, data_info, oth_params):
"""Reads source and target sequences from txt files."""
# # # Initialize file path or list of file names.
self.file = open(filename, 'r')
pass
# # # 或者从外部数据结构data_info中读取数据
self.all_texts = data_info['all_texts']
self.all_labels = data_info['all_labels']
self.vocab = data_info['vocab']
def __getitem__(self, index):
"""Returns one data pair (source and target)."""
# # # 从文件读取
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform或者word2id什么的).
# 3. Return a data pair(source and target) (e.g. image and label).
pass
# # # 或者直接读取
item_info =
"text": self.all_texts[index],
"label": self.all_labels[index]
return item_info
def __len__(self):
# You should change 0 to the total size of your dataset.
# return 0
return len(self.all_texts)
小示例
从文件中读取数据定稿Dataset
class Dataset(torch.utils.data.Dataset):
def __init__(self, filepath=None,dataLen=None):
self.file = filepath
self.dataLen = dataLen
def __getitem__(self, index):
A,B,path,hop= linecache.getline(self.file, index+1).split('\\t')
return A,B,path.split(' '),int(hop)
def __len__(self):
return self.dataLen
随机mock一个分类数据
class Dataset(data.Dataset):
"""Custom data.Dataset compatible with data.DataLoader."""
def __init__(self, df, lang: Lang):
inputs_dim = vars(Config)['inputs_dim']
self.x = torch.randint(0, 5, (5, inputs_dim), dtype=torch.float)
self.label = torch.tensor([0, 0, 1, 1, 0, 1, 0, 1, 0, 1], dtype=torch.float)
self.src_word2id = lang.word2id
self.trg_word2id = lang.word2id
# self.mem_word2id = mem_word2id
def __getitem__(self, index):
"""Returns one data pair (source and target)."""
x = self.x[index]
label = self.label[index]
item_info =
"x": x,
"label": label
return item_info
官方MNIST
的例子
(代码被缩减,只留下了重要的部分):
class MNIST(data.Dataset):
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
if download:
self.download()
if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
if self.train:
self.train_data, self.train_labels = torch.load(
os.path.join(root, self.processed_folder, self.training_file))
else:
self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
def __getitem__(self, index):
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
if self.train:
return 60000
else:
return 10000
from: -柚子皮-
ref: [pytorch学习笔记(六):自定义Datasets]
以上是关于PyTorch:数据读取1 - Datasets及数据集划分的主要内容,如果未能解决你的问题,请参考以下文章