Pytorch Note52 灵活的数据读取介绍

Posted 风信子的猫Redamancy

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch Note52 灵活的数据读取介绍相关的知识,希望对你有一定的参考价值。

Pytorch Note52 灵活的数据读取介绍


全部笔记的汇总贴: Pytorch Note 快乐星球

图片数据一般有两种情况:

1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。

2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。

针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:

灵活的数据读取

首先导入我们需要的函数

from torchvision.datasets import ImageFolder

文件中数据分布是这样的,每个文件夹中有三张图片

读入数据

# 三个文件夹,每个文件夹一共有 3 张图片作为例子
folder_set = ImageFolder('./example_data/image/')
# 查看名称和类别下标的对应
folder_set.class_to_idx
{'class_1': 0, 'class_2': 1, 'class_3': 2}
# 得到所有的图片名字和标签
folder_set.imgs
[('./example_data/image/class_1/1.png', 0),
('./example_data/image/class_1/2.png', 0),
('./example_data/image/class_1/3.png', 0),
('./example_data/image/class_2/10.png', 1),
('./example_data/image/class_2/11.png', 1),
('./example_data/image/class_2/12.png', 1),
('./example_data/image/class_3/16.png', 2),
('./example_data/image/class_3/17.png', 2),
('./example_data/image/class_3/18.png', 2)]
# 取出其中一个数据
im, label = folder_set[0]
im

label
0

传入数据预处理方式

from torchvision import transforms as tfs
# 传入数据预处理方式
data_tf = tfs.ToTensor()

folder_set = ImageFolder('./example_data/image/', transform=data_tf)

im, label = folder_set[0]
im
tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],
      [0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],
      [0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],
      ...,
      [0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],
      [0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],
      [0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],

     [[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],
      [0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],
      [0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],
      ...,
      [0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],
      [0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],
      [0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],

     [[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],
      [0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],
      [0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.1647],
      ...,
      [0.3765, 0.1333, 0.1020,  ..., 0.2745, 0.0275, 0.0784],
      [0.3765, 0.1647, 0.1176,  ..., 0.3686, 0.1333, 0.1333],
      [0.4549, 0.3686, 0.3412,  ..., 0.5490, 0.3294, 0.2824]]])
label
0

可以看到通过这种方式能够非常方便的访问每个数据点

Dataset

from torch.utils.data import Dataset
# 定义一个子类叫 custom_dataset,继承与 Dataset
class custom_dataset(Dataset):
    def __init__(self, txt_path, transform=None):
        self.transform = transform # 传入数据预处理
        with open(txt_path, 'r') as f:
            lines = f.readlines()
        
        self.img_list = [i.split()[0] for i in lines] # 得到所有的图像名字
        self.label_list = [i.split()[1] for i in lines] # 得到所有的 label 

    def __getitem__(self, idx): # 根据 idx 取出其中一个
        img = self.img_list[idx]
        label = self.label_list[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self): # 总数据的多少
        return len(self.label_list)
txt_dataset = custom_dataset('./example_data/train.txt') # 读入 txt 文件
# 取得其中一个数据
data, label = txt_dataset[0]
print(data)
print(label)
1009_2.png
YOU
# 再取一个
data2, label2 = txt_dataset[34]
print(data2)
print(label2)
1046_7.png
LIFE

所以通过这种方式我们也能够非常方便的定义一个数据读入,同时也能够方便的定义数据预处理

DataLoader

from torch.utils.data import DataLoader
train_data1 = DataLoader(folder_set, batch_size=2, shuffle=True) # 将 2 个数据作为一个 batch
for im, label in train_data1: # 访问迭代器
    print(label)
tensor([0, 2])
tensor([1, 1])
tensor([1, 2])
tensor([0, 0])
tensor([2])

可以看到,通过训练我们可以访问到所有的数据,这些数据被分为了 5 个 batch,前面 4 个都有两个数据,最后一个 batch 只有一个数据,因为一共有 9 个数据,同时顺序也被打乱了

例子

下面我们用自定义的数据读入举例子

train_data2 = DataLoader(txt_dataset, 8, True) # batch size 设置为 8
im, label = next(iter(train_data2)) # 使用这种方式访问迭代器中第一个 batch 的数据
im
('377_10.png',
'178_1.png',
'5008_4.png',
'5050_5.png',
'716_3.png',
'415_8.png',
'858_6.png',
'5086_10.png')
label
('AUGUST',
'OTKRIJTE',
'ASTAIRE',
'BOONMEE',
'OF',
'CAUTION',
'PROPANE',
'PECC')

现在有一个需求,希望能够将上面一个 batch 输出的 label 补成相同的长度,短的 label 用 0 填充,我们就需要使用 collate_fn 来自定义我们 batch 的处理方式,下面直接举例子

def collate_fn(batch):
    batch.sort(key=lambda x: len(x[1]), reverse=True) # 将数据集按照 label 的长度从大到小排序
    img, label = zip(*batch) # 将数据和 label 配对取出
    # 填充
    pad_label = []
    lens = []
    max_len = len(label[0])
    for i in range(len(label)):
        temp_label = label[i]
        temp_label += '0' * (max_len - len(label[i]))
        pad_label.append(temp_label)
        lens.append(len(label[i]))
    pad_label 
    return img, pad_label, lens # 输出 label 的真实长度

使用我们自己定义 collate_fn 看看效果

train_data3 = DataLoader(txt_dataset, 8, True, collate_fn=collate_fn) # batch size 设置为 8
im, label, lens = next(iter(train_data3))
im
('5016_1.png',
'2314_3.png',
'731_9.png',
'5019_4.png',
'208_4.png',
'5017_12.png',
'5190_1.png',
'855_12.png')
label
['LINDSAY',
'ADDRESS',
'MAIDEN0',
'EINER00',
'INDIA00',
'GERE000',
'JAWS000',
'TD00000']
lens
[7, 7, 6, 5, 5, 4, 4, 2]

可以看到一个 batch 中所有的 label 都从长到短进行排列,同时短的 label 都被补长了,所以使用 collate_fn 能够非常方便的处理一个 batch 中的数据,一般情况下,没有特别的要求,使用 pytorch 中内置的 collate_fn 就可以满足要求了

接下来是第二种情况,也是较为复杂的情况

第二种情况就是,所有文件都在一个文件夹下有图片,还必须要有事先标注好的label标签文件。

制作个人分类用数据集具体步骤如下:
1、将个人收集的图片归到一个文件夹内如下图:

from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image


def default_loader(path):
    return Image.open(path).convert('RGB')


class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\\n')
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0],int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img,label

    def __len__(self):
        return len(self.imgs)

train_data=MyDataset(txt='D:/CIFAR-10/images/data/label.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=1,shuffle=True)
print(len(data_loader))
data_loader.dataset.imgs
[('D:/CIFAR-10/images/data/0.jpg', 0),
 ('D:/CIFAR-10/images/data/1.jpg', 1),
 ('D:/CIFAR-10/images/data/2.jpg', 0),
 ('D:/CIFAR-10/images/data/3.jpg', 1),
 ('D:/CIFAR-10/images/data/4.jpg', 1),
 ('D:/CIFAR-10/images/data/5.jpg', 0),
 ('D:/CIFAR-10/images/data/6.jpg', 1),
 ('D:/CIFAR-10/images/data/7.jpg', 1),
 ('D:/CIFAR-10/images/data/8.jpg', 0),
 ('D:/CIFAR-10/images/data/9.jpg', 0)]

其实和例子是类似的

Image.open(data_loader.dataset.imgs[0][0])

以上是关于Pytorch Note52 灵活的数据读取介绍的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch Note 快乐星球

Pytorch Note 快乐星球

Pytorch Note50 Gym 介绍

Pytorch Note13 反向传播算法

Pytorch Note8 简单介绍torch.optim(优化)和模型保存

Pytorch Note46 生成对抗网络的数学原理