Pytorch Dataset和Dataloader 学习笔记

Posted Mtune

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch Dataset和Dataloader 学习笔记相关的知识,希望对你有一定的参考价值。

Pytorch Dataset & Dataloader

Pytorch框架下的工具包中,提供了数据处理的两个重要接口,Dataset 和 Dataloader,能够方便的使用和按批装载自己的数据集。

  1. 数据的预处理,加载数据并转化为tensor格式

  2. 使用Dataset构建自己的数据

  3. 使用Dataloader装载数据

【数据】链接:https://pan.baidu.com/s/1gdWFuUakuslj-EKyfyQYLA
提取码:10d4
复制这段内容后打开百度网盘手机App,操作更方便哦

数据的预处理与加载

import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset

## 1. 数据的处理,加载转化为tensor
x_data = \'X.csv\'
y_data = \'y.csv\'
x = np.loadtxt(x_data, delimiter=\' \', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=\' \', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])

torch.utils.data.Dataset

Dataset抽象类,用于包装构建自己的数据集,该类包括三个基本的方法:

  • __init__ 进行数据的读取操作
  • __getitem__ 数据集需支持索引访问
  • __len__ 返回数据集的长度
## 2. 构建自己的数据集
class Mydataset(Dataset):
    def __init__(self, train_data, label_data):
        self.train = train_data
        self.label = label_data
        self.len = len(train_data)

    def __getitem__(self, item):
        return self.train[item], self.label[item]

    def __len__(self):
        return self.len

dataset = Mydataset(x, y)
samples = dataset.__len__()
print("总样本数:",samples)

torch.utils.data.Dataloader

Dataloader抽象类,构建可迭代的数据集装载器,从Dataset实例对象中按batch_size装载数据以送入训练。包含以下几个参数:

  • batch_size 批大小
  • shuffle 装载的batch是否乱序
  • drop_last 不足batch大小的最后部分是否舍去
  • num_workers 是否多进程读取数据
## 3. 创建数据集装载器
train_loader = DataLoader(dataset=dataset,
                          batch_size=64,
                          shuffle=True,
                          drop_last=True,
                          num_workers=4)

测试

if __name__ == "__main__":
    iteration = 0
    for train_data, train_label in train_loader:
        print("x: ", train_data, "\\ny: ", train_label)
        iteration += 1
    ### 这里dataloader中drop_last为True,所以迭代次数应为 samples/batch_size = 6
    print("每个epoch迭代次数:",iteration)

完整代码

import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset

## 1. 数据的处理,加载转化为tensor
x_data = \'X.csv\'
y_data = \'y.csv\'
x = np.loadtxt(x_data, delimiter=\' \', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=\' \', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])

## 2. 构建自己的数据集
class Mydataset(Dataset):
    def __init__(self, train_data, label_data):
        self.train = train_data
        self.label = label_data
        self.len = len(train_data)

    def __getitem__(self, item):
        return self.train[item], self.label[item]

    def __len__(self):
        return self.len

dataset = Mydataset(x, y)

## 3. 创建数据集装载器
train_loader = DataLoader(dataset=dataset,
                          batch_size=64,
                          shuffle=True,
                          drop_last=True,
                          num_workers=4)

if __name__ == "__main__":
    iteration = 0
    samples = dataset.__len__()
    print("总样本数:", samples)
    for train_data, train_label in train_loader:
        print("x: ", train_data, "\\ny: ", train_label)
        iteration += 1
    ### 这里dataloader中drop_last为True,所以迭代次数应为 samples/batch_size = 6
    print("每个epoch迭代次数:",iteration)

以上是关于Pytorch Dataset和Dataloader 学习笔记的主要内容,如果未能解决你的问题,请参考以下文章

pytorch-lightning train_dataloader 用完数据

Pytorch Dataset和Dataloader 学习笔记

Pytorch的Dataset与Dataloader之间的关系

pytorch中的数据导入之DataLoader和Dataset的使用介绍

PyTorch 神经网络搭建模板

pytorch dataset dataloader