Using DataLoader
Preface: this article was compiled by the editors of 小常识网 (cha138.com) and introduces how to use DataLoader.
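The DataLoader constructor, as listed in the overview of the torch.utils.data documentation (later PyTorch releases add more parameters):
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)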
The most important argument is dataset. PyTorch supports two types of dataset:
- map-style datasets
- iterable-style datasets
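A map-style dataset must implement two methods, __getitem__() and __len__(): __getitem__(idx) returns the sample stored at index idx, and __len__() returns the number of samples. An iterable-style dataset instead subclasses torch.utils.data.IterableDataset and implements __iter__().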
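The examples below only cover map-style datasets. For comparison, here is a minimal sketch of an iterable-style dataset over the same 'abcdefg' data; the class name ExampleIterableDataset is hypothetical. Because samples come from a stream rather than from indices, DataLoader simply takes batch_size items at a time in stream order.

```python
import torch
import torch.utils.data


class ExampleIterableDataset(torch.utils.data.IterableDataset):
    """Hypothetical iterable-style counterpart of the map-style ExampleDataset."""

    def __init__(self):
        self.data = "abcdefg"

    def __iter__(self):
        # Samples are produced in stream order; there is no indexing and no __len__().
        return iter(self.data)


dataset2 = ExampleIterableDataset()
# Passing shuffle=True with an IterableDataset makes DataLoader raise a ValueError.
dataloader = torch.utils.data.DataLoader(dataset=dataset2, batch_size=2)
batches = [batch for batch in dataloader]
for batch in batches:
    print(batch)
# ['a', 'b']
# ['c', 'd']
# ['e', 'f']
# ['g']
```

Shuffling, if needed, has to happen inside __iter__() itself, since DataLoader has no indices to permute.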
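Below we use the raw data 'abcdefg' as the running example. Note how the two methods are written and how the arguments passed to torch.utils.data.DataLoader() change from one example to the next.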
Iterating over 'abcdefg' in order (shuffle=False)
import torch
import torch.utils.data

class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):  # the sample stored at index idx
        return self.data[idx]

    def __len__(self):  # the number of samples in the dataset
        return len(self.data)

dataset1 = ExampleDataset()  # create the dataset
# shuffle=False: indices are visited in order 0, 1, ..., 6
dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=False, batch_size=1)
for datapoint in dataloader:
    print(datapoint)
------output----------
['a']
['b']
['c']
['d']
['e']
['f']
['g']
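Each sample above is printed as a one-element list because automatic batching wraps every batch, even with batch_size=1. A small sketch, reusing the same dataset, showing that batch_size=None turns automatic batching off so each iteration yields exactly what __getitem__() returns:

```python
import torch
import torch.utils.data


class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


# batch_size=None disables automatic batching: each iteration yields a single
# sample as returned by __getitem__, not a one-element list.
dataloader = torch.utils.data.DataLoader(dataset=ExampleDataset(), batch_size=None)
samples = [sample for sample in dataloader]
print(samples)  # ['a', 'b', 'c', 'd', 'e', 'f', 'g']
```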
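shuffle=True: the samples are shuffled and drawn in random order (the output below is from one run; the order changes on every run)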
import torch
import torch.utils.data

class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):  # the sample stored at index idx
        return self.data[idx]

    def __len__(self):  # the number of samples in the dataset
        return len(self.data)

dataset1 = ExampleDataset()  # create the dataset
dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=True, batch_size=1)
for datapoint in dataloader:
    print(datapoint)
------output----------------
['f']
['a']
['d']
['e']
['c']
['g']
['b']
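If a shuffled order has to be reproducible, the DataLoader constructor also accepts a generator argument (not shown in the signature listed at the top) that drives the random sampler. A sketch assuming the same dataset; the helper name shuffled_order is hypothetical:

```python
import torch
import torch.utils.data


class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


def shuffled_order(seed):
    # Hypothetical helper: a seeded torch.Generator drives the RandomSampler
    # that shuffle=True creates, so the same seed gives the same order.
    g = torch.Generator()
    g.manual_seed(seed)
    loader = torch.utils.data.DataLoader(
        dataset=ExampleDataset(), shuffle=True, batch_size=1, generator=g
    )
    return [batch[0] for batch in loader]


first = shuffled_order(0)
second = shuffled_order(0)
print(first == second)  # True: same seed, same order
print(sorted(first))    # ['a', 'b', 'c', 'd', 'e', 'f', 'g']: each letter appears once
```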
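Changing batch_size: with batch_size=2 each batch holds two samples. Since 7 is odd and drop_last defaults to False, the last batch contains only one letter.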
import torch
import torch.utils.data

class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):  # the sample stored at index idx
        return self.data[idx]

    def __len__(self):  # the number of samples in the dataset
        return len(self.data)

dataset1 = ExampleDataset()  # create the dataset
dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=True, batch_size=2)
for datapoint in dataloader:
    print(datapoint)
-----------output-------------
['d', 'c']
['f', 'b']
['e', 'a']
['g']
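The leftover one-element batch can be discarded with drop_last=True, one of the parameters in the signature at the top. A minimal sketch with the same dataset:

```python
import torch
import torch.utils.data


class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


# drop_last=True discards the final incomplete batch (the single leftover letter).
dataloader = torch.utils.data.DataLoader(
    dataset=ExampleDataset(), shuffle=True, batch_size=2, drop_last=True
)
batches = list(dataloader)
print(len(batches))     # 3
print(len(dataloader))  # 3, i.e. 7 // 2
```

Which letter gets dropped depends on the shuffle, so it differs from run to run.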
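Rewriting __getitem__() and __len__() to get the result you want. Here __getitem__() returns a (lowercase, uppercase) pair. The default collate function regroups a batch of tuples field by field, so each batch is [lowercase letters, uppercase letters].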
import torch
import torch.utils.data

class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):  # a (lowercase, uppercase) pair for index idx
        return self.data[idx], self.data[idx].upper()

    def __len__(self):  # the number of samples in the dataset
        return len(self.data)

dataset1 = ExampleDataset()  # create the dataset
dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=False, batch_size=2)
for datapoint in dataloader:
    print(datapoint)
-----------output-----------
[('a', 'b'), ('A', 'B')]
[('c', 'd'), ('C', 'D')]
[('e', 'f'), ('E', 'F')]
[('g',), ('G',)]
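To keep each (lowercase, uppercase) pair together instead of regrouping by field, pass a custom collate_fn. A sketch, assuming the same dataset; keep_pairs is a hypothetical name for a collate function that returns the batch unchanged:

```python
import torch
import torch.utils.data


class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):
        return self.data[idx], self.data[idx].upper()

    def __len__(self):
        return len(self.data)


def keep_pairs(batch):
    # batch is the list of (lowercase, uppercase) tuples returned by __getitem__;
    # returning it unchanged skips the default field-by-field regrouping.
    return batch


dataloader = torch.utils.data.DataLoader(
    dataset=ExampleDataset(), shuffle=False, batch_size=2, collate_fn=keep_pairs
)
batches = list(dataloader)
for datapoint in batches:
    print(datapoint)
# [('a', 'A'), ('b', 'B')]
# [('c', 'C'), ('d', 'D')]
# [('e', 'E'), ('f', 'F')]
# [('g', 'G')]
```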
__len__() can also report more samples than the raw data holds; __getitem__() then decides what each extra index maps to:
import torch
import torch.utils.data

class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):
        if idx >= len(self.data):  # indices 7..13: wrap around and return the uppercase letter
            return self.data[idx % len(self.data)].upper()
        else:  # indices 0..6: return the lowercase letter
            return self.data[idx]

    def __len__(self):
        return 2 * len(self.data)  # the dataset now reports twice as many samples (14)

dataset1 = ExampleDataset()  # create the dataset
dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=False, batch_size=2)
for datapoint in dataloader:
    print(datapoint)
-----------output------------
['a', 'b']
['c', 'd']
['e', 'f']
['g', 'A']
['B', 'C']
['D', 'E']
['F', 'G']
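Why the batches come out this way: with shuffle=False and batch_size=2, DataLoader draws indices from a SequentialSampler over range(len(dataset)) and groups them with a BatchSampler. The sketch below rebuilds those index batches directly and looks each index up by hand, which should reproduce the DataLoader output:

```python
import torch
import torch.utils.data
from torch.utils.data import BatchSampler, SequentialSampler


class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):
        if idx >= len(self.data):
            return self.data[idx % len(self.data)].upper()
        return self.data[idx]

    def __len__(self):
        return 2 * len(self.data)


dataset1 = ExampleDataset()
# The same index batches DataLoader(shuffle=False, batch_size=2) builds internally.
index_batches = list(BatchSampler(SequentialSampler(dataset1), batch_size=2, drop_last=False))
print(index_batches)  # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]

# Looking up each index by hand gives the same batches as the DataLoader.
manual_batches = [[dataset1[i] for i in batch] for batch in index_batches]
loader_batches = list(torch.utils.data.DataLoader(dataset=dataset1, shuffle=False, batch_size=2))
print(manual_batches == loader_batches)  # True
```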