DataLoader的使用
Posted 三つ叶
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了DataLoader的使用相关的知识,希望对你有一定的参考价值。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
其中最重要的当属 dataset 一项,pytorch 支持两种类型的 dataset:
- map-style datasets
- iterable-style datasets
对于 map-style 类型的 dataset,需要实现 __getitem__() 和 __len__() 这两个函数。
下面我们以原始数据 'abcdefg' 为例进行说明。注意这两个函数的写法,以及 torch.utils.data.DataLoader() 的参数变化。
将 'abcdefg' 顺序遍历
import torch
import torch.utils.data
class ExampleDataset(torch.utils.data.Dataset):
    """Map-style dataset whose samples are the characters of "abcdefg"."""

    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):  # if the index is idx, what will be the data?
        return self.data[idx]

    def __len__(self):  # What is the length of the dataset
        return len(self.data)


dataset1 = ExampleDataset()  # create the dataset
# shuffle=False keeps the dataset order, which is what a sequential traversal
# of 'abcdefg' requires; shuffle=True would return the letters in random order.
dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=False, batch_size=1)
for datapoint in dataloader:
    print(datapoint)
------output----------
['a']
['b']
['c']
['d']
['e']
['f']
['g']
设置 shuffle=True:打乱数据顺序,随机取出
import torch
import torch.utils.data
class ExampleDataset(torch.utils.data.Dataset):
    """Seven-sample dataset built from the string "abcdefg"."""

    def __init__(self):
        self.data = "abcdefg"

    def __len__(self):
        """Return how many samples the dataset holds (one per letter)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the letter stored at position ``idx``."""
        return self.data[idx]


dataset1 = ExampleDataset()
# shuffle=True: the loader visits the samples in a random order.
dataloader = torch.utils.data.DataLoader(dataset1, batch_size=1, shuffle=True)
for datapoint in dataloader:
    print(datapoint)
------output----------------
['f']
['a']
['d']
['e']
['c']
['g']
['b']
改变batch_size
import torch
import torch.utils.data
class ExampleDataset(torch.utils.data.Dataset):
    """Dataset exposing the letters of "abcdefg", one sample per index."""

    def __init__(self):
        self.data = "abcdefg"

    def __len__(self):
        """Report the sample count."""
        return len(self.data)

    def __getitem__(self, idx):
        """Look up the letter at ``idx``."""
        return self.data[idx]


dataset1 = ExampleDataset()
# Two samples per batch in random order. With 7 samples and drop_last left at
# its default (False), the final batch holds the single leftover letter.
dataloader = torch.utils.data.DataLoader(dataset1, batch_size=2, shuffle=True)
for datapoint in dataloader:
    print(datapoint)
-----------output-------------
['d', 'c']
['f', 'b']
['e', 'a']
['g']
改写 __getitem__() 和 __len__(),以达到自己想要的结果
import torch
import torch.utils.data
class ExampleDataset(torch.utils.data.Dataset):
    """Dataset whose samples are (lowercase, uppercase) pairs from "abcdefg"."""

    def __init__(self):
        self.data = "abcdefg"

    def __len__(self):
        """One sample per letter."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the letter at ``idx`` together with its upper-case form."""
        letter = self.data[idx]
        return letter, letter.upper()


dataset1 = ExampleDataset()
# In-order batches of two. Each printed batch groups the first elements of the
# pairs together and the second elements together, e.g. [('a', 'b'), ('A', 'B')].
dataloader = torch.utils.data.DataLoader(dataset1, batch_size=2, shuffle=False)
for datapoint in dataloader:
    print(datapoint)
-----------output-----------
[('a', 'b'), ('A', 'B')]
[('c', 'd'), ('C', 'D')]
[('e', 'f'), ('E', 'F')]
[('g',), ('G',)]
import torch.utils.data
class ExampleDataset(torch.utils.data.Dataset):
    """Dataset that yields every character of ``data`` twice.

    Indices ``0 .. len(data) - 1`` return the characters unchanged and indices
    ``len(data) .. 2 * len(data) - 1`` return them upper-cased, so the default
    data "abcdefg" yields a..g followed by A..G.

    Args:
        data: Source string; defaults to "abcdefg" (the original behavior).
    """

    def __init__(self, data="abcdefg"):
        self.data = data

    def __getitem__(self, idx):  # if the index is idx, what will be the data?
        size = len(self)
        if idx < 0:  # negative indices count from the end, like a sequence
            idx += size
        if not 0 <= idx < size:
            # Without this, any index past the end silently wrapped around,
            # and plain `for x in dataset` iteration never terminated.
            raise IndexError("ExampleDataset index out of range")
        half = len(self.data)
        if idx >= half:  # second half: upper-case copy of the first half
            return self.data[idx - half].upper()
        return self.data[idx]  # first half: the lower-case original

    def __len__(self):  # What is the length of the dataset
        return 2 * len(self.data)  # The length is now twice as large


dataset1 = ExampleDataset()  # create the dataset
dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=False, batch_size=2)
for datapoint in dataloader:
    print(datapoint)
-----------output------------
['a', 'b']
['c', 'd']
['e', 'f']
['g', 'A']
['B', 'C']
['D', 'E']
['F', 'G']
带有 transform 的读取本地图片的 Dataset
import os

from PIL import Image
import torch
import torch.utils.data
from torchvision import transforms


class ExampleDataset(torch.utils.data.Dataset):
    """Map-style dataset of local images listed in a text file, with integer labels.

    Args:
        data_path: Directory containing the images and both list files.
        img_filename: Text file inside ``data_path`` with one image path per
            line; each path is joined onto ``data_path`` when loaded.
        label_filename: Text file inside ``data_path`` with one integer label
            per line, matched to the image list line by line.
        transform: Optional callable applied to each RGB PIL image.

    Each sample is ``(img, label, index)``: ``img`` is the (optionally
    transformed) image, ``label`` a one-element ``torch.long`` tensor and
    ``index`` the requested index.

    Raises:
        ValueError: if the two list files contain a different number of
            entries, or a label line is not an integer.
    """

    def __init__(self, data_path, img_filename, label_filename, transform=None):
        self.img_path = data_path
        self.transform = transform
        # reading image file names and labels from the list files
        self.img_filename = self._read_lines(os.path.join(data_path, img_filename))
        self.label = [
            int(line) for line in self._read_lines(os.path.join(data_path, label_filename))
        ]
        if len(self.img_filename) != len(self.label):
            raise ValueError(
                f"{len(self.img_filename)} image entries but {len(self.label)} labels"
            )

    @staticmethod
    def _read_lines(filepath):
        """Return the stripped, non-empty lines of a text file."""
        # Blank lines (e.g. a trailing empty line) are skipped; int('') would
        # otherwise crash label parsing.
        with open(filepath, "r", encoding="utf-8") as fp:
            return [line.strip() for line in fp if line.strip()]

    def __getitem__(self, index):
        path = os.path.join(self.img_path, self.img_filename[index])
        # The context manager closes the underlying file handle; convert()
        # reads the pixels first, so the returned image stays usable.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        label = torch.tensor([self.label[index]], dtype=torch.long)
        return img, label, index

    def __len__(self):
        return len(self.img_filename)


DATA_DIR = 'data/CIFAR-10'
DATABASE_FILE = 'database_img.txt'
DATABASE_LABEL = 'database_label.txt'

# Different network architectures usually require a specific input format;
# handle that here. Besides meeting the network's input requirements,
# transforms can also be used for data augmentation.
transformations = transforms.Compose([
    transforms.Resize(256),  # Resize replaces the deprecated transforms.Scale
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Instantiated only after the class above is defined.
dset_database = ExampleDataset(
    DATA_DIR, DATABASE_FILE, DATABASE_LABEL, transformations)
参考拓展:https://www.daimajiaoliu.com/daima/4ede05ecd1003fc
以上是关于DataLoader的使用的主要内容,如果未能解决你的问题,请参考以下文章
如何使用 PyTorch 的 DataLoader 确保批次包含来自所有 worker 进程的样本?
PyTorch DataLoader 将批次作为仅含一个元素(即该批次)的列表返回:如何以最佳方式从 DataLoader 获取张量?