PyTorch notes: DataLoader extension: building an image DataLoader

Posted by UQI-LIUWJ

Data source: OneDrive for Business

Background notes (UQI-LIUWJ's CSDN blog):

PyTorch notes: Dataloader

torchvision notes: ToTensor()

torchvision notes: transforms.Normalize()

torchvision notes: transforms.Compose()

1 Data format

Running tree /F in the Windows cmd shows the following layout:

─img
│      00000.jpg
│      00001.jpg
│      00002.jpg
│      00003.jpg
│      00004.jpg
│      00005.jpg
│      ......
│      06998.jpg
│      06999.jpg
│
└─split
        list_attr_cloth.txt
        test.txt
        test_bbox.txt
        test_landmards.txt
        train.txt
        train_attr.txt
        train_bbox.txt
        train_landmards.txt
        val.txt
        val_attr.txt
        val_bbox.txt
        val_landmards.txt

For now, only train.txt and train_attr.txt are used.

1.1 train.txt

The first five lines (one image path per line, relative to the dataset root):

img/00000.jpg
img/00001.jpg
img/00002.jpg
img/00003.jpg
img/00004.jpg
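Because each path is relative to the dataset root (the folder holding img/ and split/), the Dataset in section 2 only finds the images when the script runs from that folder. The following is a minimal sketch of resolving the paths against an explicit root instead. The helper name `resolve_paths` and its `root` argument are hypothetical and not part of the original code.

```python
import os

def resolve_paths(lines, root):
    """Join each relative path from train.txt onto the dataset root.

    `lines` is any iterable of strings, e.g. an open file object.
    `root` is assumed to be the folder containing both img/ and split/.
    """
    paths = []
    for line in lines:
        rel = line.strip()  # drop the trailing newline
        if rel:  # skip blank lines such as an empty final line
            paths.append(os.path.join(root, rel))
    return paths

# Hypothetical usage against the layout above:
# with open(os.path.join(root, "split", "train.txt")) as f:
#     paths = resolve_paths(f, root)
```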

1.2 train_attr.txt

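Again, only the first five lines are shown. Each line holds six space-separated integers: the class this image belongs to for each of the 6 attributes.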

5 0 2 0 2 2
5 1 2 0 5 1
5 0 2 3 4 2
6 2 1 3 2 2
0 2 1 3 2 2
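To pick out attribute k for one image, split the line on whitespace and then index the resulting fields. Indexing the raw string by character position only happens to work for k = 0, because positions 1, 3, 5, ... are spaces. It would also break for any two-digit label. The following sketch uses the first lines shown above. The function names are illustrative, not from the original.

```python
def parse_attr_line(line):
    """Turn '5 0 2 0 2 2' into [5, 0, 2, 0, 2, 2]: one int per attribute."""
    return [int(tok) for tok in line.split()]

def get_attr(line, attr_no):
    """Return the label of attribute `attr_no` (0-based) for one image."""
    return parse_attr_line(line)[attr_no]

line = "5 1 2 0 5 1\n"  # second line of train_attr.txt
print(parse_attr_line(line))  # [5, 1, 2, 0, 5, 1]
print(get_attr(line, 1))      # 1
print(repr(line.strip()[1]))  # ' '  <- character indexing lands on a space
```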

2 Creating the DataLoader

2.1 Importing libraries

from PIL import Image
from torchvision import transforms, utils
import torch
from torch.utils.data import Dataset, DataLoader

2.2 preprocess

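Every input image goes through ToTensor and then normalization. The mean and std values are the per-channel (R, G, B) statistics commonly used with ImageNet-pretrained models.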

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
            )
])
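For 8-bit images, ToTensor maps each pixel value from 0..255 to 0.0..1.0. Normalize then computes (x - mean[c]) / std[c] separately for each channel c. The following pure-Python sketch shows only this arithmetic for a single RGB pixel, so it can be checked by hand. It is not torchvision's implementation.

```python
MEAN = [0.485, 0.456, 0.406]  # per-channel mean (R, G, B), as in preprocess
STD = [0.229, 0.224, 0.225]   # per-channel std (R, G, B), as in preprocess

def normalize_pixel(rgb):
    """Mimic ToTensor + Normalize for one 8-bit RGB pixel.

    ToTensor maps 0..255 to 0.0..1.0; Normalize then applies
    (x - mean) / std independently to each channel.
    """
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD)]

print([round(x, 4) for x in normalize_pixel((255, 255, 255))])  # [2.2489, 2.4286, 2.64]
print([round(x, 4) for x in normalize_pixel((0, 0, 0))])        # [-2.1179, -2.0357, -1.8044]
```

A pure white pixel becomes [2.2489, 2.4286, 2.64], the same constant values that appear in the batch printed in section 2.6.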

2.3 Loading an image from its path and turning it into a Tensor

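def default_loader(path):
    # Force 3 channels: grayscale ("L") or RGBA images would not match
    # the 3-value mean/std in Normalize and would raise an error
    with Image.open(path) as img:
        img_pil = img.convert('RGB')
    img_pil = img_pil.resize((224, 224))  # (width, height); aspect ratio is not kept
    img_tensor = preprocess(img_pil)      # ToTensor + Normalize -> shape [3, 224, 224]
    return img_tensor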
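PIL's resize takes (width, height). ToTensor also changes the memory layout: the image is stored as height × width × channel, but the tensor comes out channels-first as (C, H, W), here (3, 224, 224). The following pure-Python sketch of that rearrangement uses a tiny 2 × 3 image. `hwc_to_chw` is a hypothetical helper for illustration only.

```python
def hwc_to_chw(pixels):
    """Rearrange an H x W x C nested list (rows of RGB tuples) into C x H x W,
    scaling 8-bit values to [0, 1] the way ToTensor does for uint8 images."""
    height = len(pixels)
    width = len(pixels[0])
    channels = len(pixels[0][0])
    return [
        [[pixels[h][w][c] / 255.0 for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]

# A 2 x 3 image (2 rows, 3 columns): red, green, blue on top; white below.
img = [
    [(255, 0, 0), (0, 255, 0), (0, 0, 255)],
    [(255, 255, 255), (255, 255, 255), (255, 255, 255)],
]
chw = hwc_to_chw(img)
print(len(chw), len(chw[0]), len(chw[0][0]))  # 3 2 3
print(chw[0])  # red channel: [[1.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
```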

2.4 Defining the Dataset

As before, the Dataset must implement __getitem__ and __len__.
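A map-style dataset only needs to support `dataset[i]` and `len(dataset)`. The following duck-typed sketch in pure Python (no torch import) shows that protocol on its own. `PairDataset` is a hypothetical name, and its `transform` argument plays the role that `loader` plays in trainset.

```python
class PairDataset:
    """Minimal map-style dataset: __getitem__ returns one (sample, label)
    pair and __len__ returns the number of samples."""

    def __init__(self, samples, labels, transform=None):
        self.samples = samples
        self.labels = labels
        self.transform = transform  # e.g. a path -> tensor loader

    def __getitem__(self, index):
        x = self.samples[index]
        if self.transform is not None:
            x = self.transform(x)
        return x, self.labels[index]

    def __len__(self):
        return len(self.samples)

ds = PairDataset(["img/00000.jpg", "img/00001.jpg"], [5, 5], transform=str.upper)
print(len(ds))  # 2
print(ds[1])    # ('IMG/00001.JPG', 5)
```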

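class trainset(Dataset):
    def __init__(self,
                 loader=default_loader,
                 img_path='split/train.txt',
                 tgt_path='split/train_attr.txt',
                 attr_no=0):
        # img_path: text file listing the training image paths, one per line
        with open(img_path, 'r') as f:
            self.f_tmp = f.readlines()
        # tgt_path: text file with the labels of the training images, one row per image
        with open(tgt_path, 'r') as f:
            self.t_tmp = f.readlines()
        self.loader = loader    # reads an image from its path -> Tensor
        self.attr_no = attr_no  # there are 6 attributes; this instance uses one (0-5)

    def __getitem__(self, index):
        fn = self.f_tmp[index].strip()  # e.g. 'img/00001.jpg'
        img = self.loader(fn)           # read the image from its path -> Tensor
        # Split the row into its 6 fields so attr_no selects a whole label,
        # not a single character (character 1 of '5 0 2 0 2 2' is a space).
        # The label stays a string such as '5'; wrap it in int(...) if the loss needs numbers.
        tt = self.t_tmp[index].split()[self.attr_no]
        return img, tt

    def __len__(self):
        return len(self.f_tmp)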
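`__len__` counts only the image paths. If train_attr.txt were shorter than train.txt, the mismatch would only show up as an IndexError inside `__getitem__` partway through an epoch. If the label file were longer, the extra rows would be silently ignored. The following is a sketch of an upfront sanity check under the assumption of exactly 6 attributes per row. `check_split` is a hypothetical helper.

```python
def check_split(img_lines, attr_lines, num_attrs=6):
    """Check that the image list and the label file line up.

    Returns the number of samples, or raises ValueError on a mismatch.
    Blank lines in either file are ignored.
    """
    paths = [line.strip() for line in img_lines if line.strip()]
    rows = [line.split() for line in attr_lines if line.strip()]
    if len(paths) != len(rows):
        raise ValueError(f"{len(paths)} image paths but {len(rows)} label rows")
    for row_no, fields in enumerate(rows):
        if len(fields) != num_attrs:
            raise ValueError(
                f"label row {row_no} has {len(fields)} fields, expected {num_attrs}"
            )
    return len(paths)

print(check_split(["img/00000.jpg\n", "img/00001.jpg\n"],
                  ["5 0 2 0 2 2\n", "5 1 2 0 5 1\n"]))  # 2
```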

2.5 Creating the DataLoader

loader=DataLoader(
    trainset(),
    batch_size=4,
    shuffle=True)
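With shuffle=True, each pass over the DataLoader draws a new random order of indices and cuts it into chunks of batch_size=4. The last chunk may be smaller unless drop_last=True. The default collate function then splits each list of (image, label) pairs into one batch of images, stacked to shape [4, 3, 224, 224], and one batch of labels. For string labels like these, the label batch is a tuple of strings. The following pure-Python sketch shows only the index, chunk, and transpose flow. `iterate_batches` is hypothetical, and tensor stacking is replaced by plain tuples.

```python
import random

def iterate_batches(dataset, batch_size, shuffle, seed=None, drop_last=False):
    """Yield (xs, ys) batches the way a map-style DataLoader would, in pure Python.

    One call corresponds to one epoch. Each batch of (x, y) pairs is
    transposed into a tuple of inputs and a tuple of labels.
    """
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # random order for this epoch
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        if drop_last and len(chunk) < batch_size:
            break  # discard the incomplete final batch
        samples = [dataset[i] for i in chunk]
        xs, ys = zip(*samples)  # transpose pairs into two tuples
        yield xs, ys

data = [(f"img{i}", str(i % 3)) for i in range(10)]
batches = list(iterate_batches(data, batch_size=4, shuffle=True, seed=0))
print([len(xs) for xs, _ in batches])  # [4, 4, 2]
```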

2.6 Checking the result

Fetch one batch: four images and their corresponding labels.

for step,(batch_x,batch_x_y) in enumerate(loader):
    print(batch_x,batch_x_y)
    break
'''
tensor([[[[2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          ...,
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489]],

         [[2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          ...,
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286]],

         [[2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          ...,
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400]]],


        [[[2.1633, 2.1633, 2.1633,  ..., 2.1804, 2.1804, 2.1804],
          [2.1633, 2.1633, 2.1633,  ..., 2.1804, 2.1804, 2.1804],
          [2.1633, 2.1633, 2.1633,  ..., 2.1804, 2.1804, 2.1804],
          ...,
          [2.0777, 2.0777, 2.0605,  ..., 2.1290, 2.1290, 2.1290],
          [2.1119, 2.1119, 2.1119,  ..., 2.1290, 2.1290, 2.1290],
          [2.1462, 2.1462, 2.1290,  ..., 2.1119, 2.1119, 2.1119]],

         [[2.3410, 2.3410, 2.3410,  ..., 2.3585, 2.3585, 2.3585],
          [2.3410, 2.3410, 2.3410,  ..., 2.3585, 2.3585, 2.3585],
          [2.3410, 2.3410, 2.3410,  ..., 2.3585, 2.3585, 2.3585],
          ...,
          [2.2360, 2.2360, 2.2185,  ..., 2.3060, 2.3060, 2.3060],
          [2.2710, 2.2710, 2.2710,  ..., 2.3060, 2.3060, 2.3060],
          [2.3060, 2.3060, 2.2885,  ..., 2.2885, 2.2885, 2.2885]],

         [[2.5529, 2.5529, 2.5529,  ..., 2.5703, 2.5703, 2.5703],
          [2.5529, 2.5529, 2.5529,  ..., 2.5703, 2.5703, 2.5703],
          [2.5529, 2.5529, 2.5529,  ..., 2.5703, 2.5703, 2.5703],
          ...,
          [2.4134, 2.4134, 2.3960,  ..., 2.5180, 2.5180, 2.5180],
          [2.4483, 2.4483, 2.4483,  ..., 2.5180, 2.5180, 2.5180],
          [2.4831, 2.4831, 2.4657,  ..., 2.5006, 2.5006, 2.5006]]],


        [[[2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          ...,
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2318, 2.2318, 2.2318,  ..., 2.2489, 2.2489, 2.2489]],

         [[2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          ...,
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4111, 2.4111, 2.4111,  ..., 2.4286, 2.4286, 2.4286]],

         [[2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          ...,
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6226, 2.6226, 2.6226,  ..., 2.6400, 2.6400, 2.6400]]],


        [[[2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          ...,
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
          [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489]],

         [[2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          ...,
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
          [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286]],

         [[2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          ...,
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400],
          [2.6400, 2.6400, 2.6400,  ..., 2.6400, 2.6400, 2.6400]]]]) ('1', '1', '0', '3')
'''
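Inverting the normalization is one way to sanity-check a printout like this: x = y * std + mean, then multiply by 255. The following sketch assumes the same mean/std as preprocess. `denormalize_pixel` is a hypothetical helper. The constant triple 2.2489 / 2.4286 / 2.6400 maps back to pure white (255, 255, 255), so at the printed positions, which are the first and last rows and columns, the first, third and fourth images are white. This suggests white backgrounds. The first pixel of the second image comes out as a light gray.

```python
MEAN = [0.485, 0.456, 0.406]  # same per-channel mean as preprocess
STD = [0.229, 0.224, 0.225]   # same per-channel std as preprocess

def denormalize_pixel(values):
    """Undo Normalize and ToTensor for one pixel: x = y * std + mean,
    then scale back to 0..255 and round to the nearest integer."""
    return [round((y * s + m) * 255) for y, m, s in zip(values, MEAN, STD)]

print(denormalize_pixel([2.2489, 2.4286, 2.6400]))  # [255, 255, 255]
print(denormalize_pixel([2.1633, 2.3410, 2.5529]))  # [250, 250, 250]
```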
