PyTorch notes: DataLoader extension: building an image DataLoader
Posted by UQI-LIUWJ
Related notes:
pytorch notes: Dataloader (UQI-LIUWJ's CSDN blog)
torchvision notes: ToTensor() (UQI-LIUWJ's CSDN blog)
torchvision notes: transforms.Normalize() (UQI-LIUWJ's CSDN blog)
torchvision notes: transforms.Compose() (UQI-LIUWJ's CSDN blog)
1 Data format
Running tree /F in the Windows cmd shows the dataset layout:
├─img
│      00000.jpg
│      00001.jpg
│      00002.jpg
│      00003.jpg
│      00004.jpg
│      00005.jpg
│      ...
│      06998.jpg
│      06999.jpg
│
└─split
        list_attr_cloth.txt
        test.txt
        test_bbox.txt
        test_landmards.txt
        train.txt
        train_attr.txt
        train_bbox.txt
        train_landmards.txt
        val.txt
        val_attr.txt
        val_bbox.txt
        val_landmards.txt
For now we only use train.txt and train_attr.txt.
1.1 train.txt
The first five lines:
img/00000.jpg
img/00001.jpg
img/00002.jpg
img/00003.jpg
img/00004.jpg
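Each line is a path relative to the dataset root (the directory that contains img/ and split/), and readlines() keeps the trailing newline, so every line has to be stripped before it can be opened. A minimal sketch of reading such a list and checking that every listed image exists; read_list, missing_images and the temporary directory are hypothetical stand-ins for the real dataset:

```python
import os
import tempfile


def read_list(list_path):
    """Return the non-empty lines of a list file with newlines stripped."""
    with open(list_path, 'r') as f:
        return [line.strip() for line in f if line.strip()]


def missing_images(root, rel_paths):
    """Return the listed relative paths that do not exist under root."""
    return [p for p in rel_paths if not os.path.isfile(os.path.join(root, p))]


# Tiny hypothetical stand-in for the real dataset directory
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, 'img'))
os.makedirs(os.path.join(root, 'split'))
for name in ('00000.jpg', '00001.jpg'):
    open(os.path.join(root, 'img', name), 'wb').close()  # empty placeholder files
with open(os.path.join(root, 'split', 'train.txt'), 'w') as f:
    f.write('img/00000.jpg\nimg/00001.jpg\nimg/00002.jpg\n')

paths = read_list(os.path.join(root, 'split', 'train.txt'))
print(paths)                        # ['img/00000.jpg', 'img/00001.jpg', 'img/00002.jpg']
print(missing_images(root, paths))  # ['img/00002.jpg']
```

Running a check like this once before training surfaces broken paths immediately instead of as a FileNotFoundError in the middle of an epoch.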
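1.2 train_attr.txt
Again, just the first five lines. Each line holds six space-separated integers: the class this image belongs to for each of the 6 attributes.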
5 0 2 0 2 2
5 1 2 0 5 1
5 0 2 3 4 2
6 2 1 3 2 2
0 2 1 3 2 2
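Because the labels are whitespace-separated, a line has to be split before picking an attribute by number; indexing the string directly returns single characters, and at odd positions those are spaces. A short sketch (parse_attrs is a hypothetical helper):

```python
def parse_attrs(line):
    """Split one line of train_attr.txt into a list of ints, one per attribute."""
    return [int(tok) for tok in line.split()]


line = '5 1 2 0 5 1\n'  # second line of train_attr.txt

print(parse_attrs(line))       # [5, 1, 2, 0, 5, 1]
print(parse_attrs(line)[1])    # 1   -> label of attribute 1
print(repr(line.strip()[1]))   # ' ' -> character 1 is a space, not attribute 1
print(repr(line.strip()[2]))   # '1' -> attribute 1 sits at character 2
```

Splitting also keeps working if a label ever has more than one digit, which character positions would not.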
2 Creating the DataLoader
2.1 Import libraries
from PIL import Image
from torchvision import transforms, utils
import torch
from torch.utils.data import Dataset, DataLoader
2.2 preprocess
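Every input image goes through ToTensor (PIL image -> float tensor of shape [C, H, W] with values in [0, 1]) and then per-channel normalization:
preprocess = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor [C, H, W] in [0, 1]
    transforms.Normalize(   # per channel: (x - mean) / std
        mean=[0.485, 0.456, 0.406],  # ImageNet channel means (R, G, B)
        std=[0.229, 0.224, 0.225]    # ImageNet channel standard deviations
    )
])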
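ToTensor first scales pixel values from [0, 255] to [0, 1]; Normalize then computes (x - mean) / std for each channel. That explains the values printed in section 2.6: a pure-white pixel (1.0 in every channel) becomes 2.2489, 2.4286 and 2.6400 in the R, G and B channels. A plain-Python sketch of the arithmetic and its inverse (normalize and denormalize are illustrative helpers, not torchvision functions):

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize(pixel):
    """Per-channel (x - mean) / std for an RGB value already scaled to [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(pixel, MEAN, STD))


def denormalize(values):
    """Inverse of normalize: x = y * std + mean (useful before displaying an image)."""
    return tuple(y * s + m for y, m, s in zip(values, MEAN, STD))


white = normalize((1.0, 1.0, 1.0))
black = normalize((0.0, 0.0, 0.0))
print([round(v, 4) for v in white])               # [2.2489, 2.4286, 2.64]
print([round(v, 4) for v in black])               # [-2.1179, -2.0357, -1.8044]
print([round(v, 6) for v in denormalize(white)])  # [1.0, 1.0, 1.0]
```

So after this preprocessing each channel lies roughly in [-2.1, 2.7] rather than [0, 1]; the inverse is needed whenever a batch is turned back into a viewable image.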
2.3 Load an image from its path --> turn it into a Tensor
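def default_loader(path):
    # convert('RGB') forces 3 channels; grayscale or RGBA images would break Normalize
    img_pil = Image.open(path).convert('RGB')
    img_pil = img_pil.resize((224, 224))  # size is (width, height)
    img_tensor = preprocess(img_pil)       # tensor of shape [3, 224, 224]
    return img_tensor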
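Image.open returns the image in whatever mode it was stored in. A grayscale JPEG (mode 'L') becomes a 1-channel tensor after ToTensor and an RGBA PNG a 4-channel one, and Normalize with three means and stds then typically fails with a shape-mismatch error. A minimal illustration with hypothetical in-memory images (no files needed) of how convert('RGB') fixes the channel count; note that PIL's resize takes (width, height):

```python
from PIL import Image

# Hypothetical in-memory images standing in for files on disk
gray = Image.new('L', (300, 200), color=128)               # like a grayscale JPEG
rgba = Image.new('RGBA', (300, 200), (255, 0, 0, 255))     # like a PNG with alpha

for img in (gray, rgba):
    fixed = img.convert('RGB').resize((224, 224))
    print(img.mode, len(img.getbands()), '->', fixed.mode, len(fixed.getbands()), fixed.size)
# L 1 -> RGB 3 (224, 224)
# RGBA 4 -> RGB 3 (224, 224)
```

Converting an image that is already RGB is cheap, so applying it unconditionally is simpler than checking the mode first.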
2.4 Define the Dataset
As before, the Dataset must implement __getitem__ and __len__.
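class trainset(Dataset):
    def __init__(self,
                 loader=default_loader,
                 img_path='split/train.txt',
                 tgt_path='split/train_attr.txt',
                 attr_no=0):
        # paths are relative to the dataset root, so run from that directory
        # img_path: text file listing the training image paths, one per line
        with open(img_path, 'r') as f:
            self.f_tmp = f.readlines()
        # tgt_path: text file with the labels of the training images, one line per image
        with open(tgt_path, 'r') as f:
            self.t_tmp = f.readlines()
        self.loader = loader    # reads an image from its path -> Tensor
        self.attr_no = attr_no  # which of the six attributes to use as the label

    def __getitem__(self, index):
        fn = self.f_tmp[index].strip()  # e.g. 'img/00001.jpg'
        img = self.loader(fn)           # read the image from its path -> Tensor
        # split on whitespace first: indexing the raw line would return
        # single characters (including spaces) instead of labels
        tt = self.t_tmp[index].split()[self.attr_no]  # label as a str, e.g. '1'
        return img, tt

    def __len__(self):
        return len(self.f_tmp)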
2.5 Create the DataLoader
loader=DataLoader(
trainset(),
batch_size=4,
shuffle=True)
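batch_size=4 groups four samples per step and shuffle=True draws a new random order every epoch, so one epoch has ceil(N / batch_size) batches, with a smaller final batch unless drop_last=True. A quick check on a hypothetical 10-sample TensorDataset standing in for trainset():

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 fake samples stand in for trainset(); only the count matters here
data = TensorDataset(torch.zeros(10, 3, 8, 8), torch.arange(10))

loader = DataLoader(data, batch_size=4, shuffle=True)
print(len(loader))                   # 3 batches: 4 + 4 + 2
print([len(y) for _, y in loader])   # [4, 4, 2]

loader_drop = DataLoader(data, batch_size=4, shuffle=True, drop_last=True)
print(len(loader_drop))              # 2: the incomplete last batch is dropped
```

drop_last=True is mainly useful when a layer such as BatchNorm behaves badly on very small final batches.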
2.6 Check the result
One batch: four images and their corresponding labels.
for step, (batch_x, batch_y) in enumerate(loader):
    print(batch_x, batch_y)
    break
'''
tensor([[[[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
...,
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489]],
[[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
...,
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286]],
[[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
...,
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400]]],
[[[2.1633, 2.1633, 2.1633, ..., 2.1804, 2.1804, 2.1804],
[2.1633, 2.1633, 2.1633, ..., 2.1804, 2.1804, 2.1804],
[2.1633, 2.1633, 2.1633, ..., 2.1804, 2.1804, 2.1804],
...,
[2.0777, 2.0777, 2.0605, ..., 2.1290, 2.1290, 2.1290],
[2.1119, 2.1119, 2.1119, ..., 2.1290, 2.1290, 2.1290],
[2.1462, 2.1462, 2.1290, ..., 2.1119, 2.1119, 2.1119]],
[[2.3410, 2.3410, 2.3410, ..., 2.3585, 2.3585, 2.3585],
[2.3410, 2.3410, 2.3410, ..., 2.3585, 2.3585, 2.3585],
[2.3410, 2.3410, 2.3410, ..., 2.3585, 2.3585, 2.3585],
...,
[2.2360, 2.2360, 2.2185, ..., 2.3060, 2.3060, 2.3060],
[2.2710, 2.2710, 2.2710, ..., 2.3060, 2.3060, 2.3060],
[2.3060, 2.3060, 2.2885, ..., 2.2885, 2.2885, 2.2885]],
[[2.5529, 2.5529, 2.5529, ..., 2.5703, 2.5703, 2.5703],
[2.5529, 2.5529, 2.5529, ..., 2.5703, 2.5703, 2.5703],
[2.5529, 2.5529, 2.5529, ..., 2.5703, 2.5703, 2.5703],
...,
[2.4134, 2.4134, 2.3960, ..., 2.5180, 2.5180, 2.5180],
[2.4483, 2.4483, 2.4483, ..., 2.5180, 2.5180, 2.5180],
[2.4831, 2.4831, 2.4657, ..., 2.5006, 2.5006, 2.5006]]],
[[[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
...,
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2318, 2.2318, 2.2318, ..., 2.2489, 2.2489, 2.2489]],
[[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
...,
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4111, 2.4111, 2.4111, ..., 2.4286, 2.4286, 2.4286]],
[[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
...,
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6226, 2.6226, 2.6226, ..., 2.6400, 2.6400, 2.6400]]],
[[[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
...,
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489]],
[[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
...,
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286]],
[[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
...,
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400]]]]) ('1', '1', '0', '3')
'''
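The labels come out as ('1', '1', '0', '3') because trainset.__getitem__ returns each label as a str. When no collate_fn is given, DataLoader uses default_collate, which stacks the image tensors into one [4, 3, 224, 224] tensor but leaves strings as a plain sequence, whereas Python ints are turned into a tensor. A small sketch with hypothetical 2x2 images shows both cases:

```python
import torch
from torch.utils.data.dataloader import default_collate

# Each sample is (image, label); tiny fake images stand in for real ones
str_batch = [(torch.zeros(3, 2, 2), '1'), (torch.ones(3, 2, 2), '0')]
int_batch = [(torch.zeros(3, 2, 2), 1), (torch.ones(3, 2, 2), 0)]

imgs, str_labels = default_collate(str_batch)
print(imgs.shape, str_labels)  # torch.Size([2, 3, 2, 2]) ('1', '0')

_, int_labels = default_collate(int_batch)
print(int_labels)              # tensor([1, 0])
```

Returning int(label) from __getitem__ is therefore enough to get a label tensor that can go straight into a loss function.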