Using DataLoader

Official documentation

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
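
Besides dataset, the arguments you will touch most often are batch_size, shuffle, num_workers, and drop_last. As a quick orientation, here is a minimal sketch of typical usage with torch.utils.data.TensorDataset; the tensors and values are made up purely for illustration:

import torch
import torch.utils.data

features = torch.randn(10, 3)        # 10 samples with 3 features each (made-up data)
labels = torch.randint(0, 2, (10,))  # 10 binary labels (made-up data)
dataset = torch.utils.data.TensorDataset(features, labels)

loader = torch.utils.data.DataLoader(dataset, batch_size = 4, shuffle = True,
                                     num_workers = 0, drop_last = False)
for x, y in loader:
  print(x.shape, y.shape)  # torch.Size([4, 3]) torch.Size([4]) for full batches, [2, 3] and [2] for the last one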

The most important argument is dataset. PyTorch supports two types of dataset:

  • map-style datasets
  • iterable-style datasets

A map-style dataset needs to implement two methods: __getitem__() and __len__().
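
For contrast, an iterable-style dataset only needs to implement __iter__(); the DataLoader then simply consumes whatever the iterator yields (shuffle and sampler cannot be used with it). A minimal sketch, where ExampleIterableDataset is just an illustrative name:

import torch
import torch.utils.data

class ExampleIterableDataset(torch.utils.data.IterableDataset):
  def __init__(self):
    self.data = "abcdefg"

  def __iter__(self): # yield the samples one by one; no random access, no __len__ required
    return iter(self.data)

dataloader = torch.utils.data.DataLoader(ExampleIterableDataset(), batch_size = 2)
for datapoint in dataloader:
  print(datapoint)
# ['a', 'b'] ... ['g']

All of the examples below use map-style datasets.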

Below we use the raw data 'abcdefg' as an example. Note how the two methods are written, and how the arguments passed to torch.utils.data.DataLoader() change.

Iterating over 'abcdefg' in order (shuffle=False)

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = False,batch_size = 1)
for datapoint in dataloader:
  print(datapoint)
  
  
------output----------
['a']
['b']
['c']
['d']
['e']
['f']
['g']

Each sample is printed as a list such as ['a'] because the DataLoader's default collate_fn collects the batch_size samples of each batch into a list.

Setting shuffle=True shuffles the order, so the samples are drawn randomly:

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = True,batch_size = 1)
for datapoint in dataloader:
  print(datapoint)
  
------output----------------
['f']
['a']
['d']
['e']
['c']
['g']
['b']
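
If you want the shuffled order to be reproducible across runs, you can pass a seeded torch.Generator through the generator argument. A minimal sketch, assuming dataset1 from the example above is already defined:

import torch

g = torch.Generator()
g.manual_seed(0) # fix the seed so every run shuffles the indices the same way
dataloader = torch.utils.data.DataLoader(dataset = dataset1, shuffle = True,
                                         batch_size = 1, generator = g)
for datapoint in dataloader:
  print(datapoint) # the same shuffled order on every run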

Changing batch_size

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = True,batch_size = 2)
for datapoint in dataloader:
  print(datapoint)


-----------output-------------
['d', 'c']
['f', 'b']
['e', 'a']
['g']
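
The last batch contains only ['g'] because 7 samples do not divide evenly into batches of 2 and drop_last defaults to False. If you do not want the incomplete batch, set drop_last=True. A minimal sketch, assuming dataset1 from the example above is already defined:

dataloader = torch.utils.data.DataLoader(dataset = dataset1, shuffle = True,
                                         batch_size = 2, drop_last = True)
for datapoint in dataloader:
  print(datapoint) # only the three full batches of 2 are produced; the leftover 'g' is dropped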

Rewriting __getitem__() and __len__() to get the output you want

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx], self.data[idx].upper()
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = False,batch_size = 2)
for datapoint in dataloader:
  print(datapoint)


-----------output-----------
[('a', 'b'), ('A', 'B')]
[('c', 'd'), ('C', 'D')]
[('e', 'f'), ('E', 'F')]
[('g',), ('G',)]
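
Because __getitem__() now returns a tuple, the default collate_fn batches each position of the tuple separately: the lowercase letters of a batch are grouped together and the uppercase letters are grouped together, which is why the output looks transposed. If you would rather keep each sample's (lower, upper) pair intact, you can pass a custom collate_fn. A minimal sketch, where pair_collate is just an illustrative name and dataset1 is the dataset defined above:

def pair_collate(batch):
  # batch is the raw list of samples, e.g. [('a', 'A'), ('b', 'B')]; return it unchanged
  return batch

dataloader = torch.utils.data.DataLoader(dataset = dataset1, shuffle = False,
                                         batch_size = 2, collate_fn = pair_collate)
for datapoint in dataloader:
  print(datapoint)
# [('a', 'A'), ('b', 'B')]
# [('c', 'C'), ('d', 'D')]
# [('e', 'E'), ('f', 'F')]
# [('g', 'G')]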
Another variant: make __len__() report twice the length of the raw data, so that indices 0-6 return the lowercase letters and indices 7-13 return the corresponding uppercase letters.

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    if idx >= len(self.data): # if the index is 7 or larger, return the upper-case letter
      return self.data[idx%7].upper()
    else: # if the index is smaller than 7, return the lower-case letter
      return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return 2 * len(self.data) # The length is now twice as large

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = False,batch_size = 2)
for datapoint in dataloader:
  print(datapoint)

-----------output------------
['a', 'b']
['c', 'd']
['e', 'f']
['g', 'A']
['B', 'C']
['D', 'E']
['F', 'G']
