paddlepaddle实现十二生肖的分类之数据的预处理

Posted 修炼之路

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了paddlepaddle实现十二生肖的分类之数据的预处理相关的知识,希望对你有一定的参考价值。

数据集说明

数据集一共包含3个目录trainvalidtest,每个目录都包含了12生肖(类别)的图片,通过下面的链接可以直接下载数据集

数据下载地址:下载地址
项目地址:项目链接

数据分析

统计数据集中每个类别的数据分布情况

import os

def print_classes_info(mode="train",data_dir = "data/signs"):
    datasets_dir = os.path.join(data_dir,mode)
    classes_names = os.listdir(datasets_dir)
    #用来保存每个类别的数量信息
    classes_num_infos = dict()
    for class_name in classes_names:
        #获取类别的目录
        class_dir_path = os.path.join(datasets_dir,class_name)
        img_names = os.listdir(class_dir_path)
        #记录每个类别的图片数量
        classes_num_infos[class_name] = len(img_names)
    print(":".format(mode,classes_num_infos))

#打印数据的分布情况
print_classes_info("train")
print_classes_info("valid")
print_classes_info("test")

train:‘goat’: 600, ‘tiger’: 600, ‘horse’: 600, ‘snake’: 600, ‘pig’: 600, ‘dragon’: 600, ‘ox’: 600, ‘monkey’: 600, ‘rabbit’: 600, ‘rooster’: 600, ‘dog’: 600, ‘ratt’: 600
valid:‘goat’: 55, ‘tiger’: 55, ‘horse’: 55, ‘snake’: 55, ‘pig’: 55, ‘dragon’: 55, ‘ox’: 55, ‘monkey’: 55, ‘rabbit’: 55, ‘rooster’: 55, ‘dog’: 55, ‘ratt’: 55
test:‘goat’: 55, ‘tiger’: 55, ‘horse’: 55, ‘snake’: 55, ‘pig’: 55, ‘dragon’: 55, ‘ox’: 55, ‘monkey’: 55, ‘rabbit’: 55, ‘rooster’: 55, ‘dog’: 55, ‘ratt’: 55

在训练集中每个类别包含600张图片,验证集中每个类别包含55张图片,测试集中每个类别包含55张图片,因为这里的数据都比较平衡,后面我们就不需要去考虑数据的平衡问题了。

展示图片

从数据集中选择一部分数据进行查看,了解数据的分布特征

import os
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt

def show_images(mode="train",data_dir = "data/signs",row_num=3,col_num=4):
    datasets_dir = os.path.join(data_dir,mode)
    #获取所有的类别数据
    classes_names = os.listdir(datasets_dir)
    #用来保存图片数据
    images_list = []
    title_list = []
    for cls_name in classes_names:
        cls_dir_path = os.path.join(datasets_dir,cls_name)
        img_name_list = os.listdir(cls_dir_path)
        for img_name in img_name_list:
            img_path = os.path.join(cls_dir_path,img_name)
            image = Image.open(img_path)
            #保存图片和标签数据
            images_list.append(np.array(image))
            title_list.append(cls_name)
            break
            
    #设置图片的大小
    plt.figure(figsize=(8,8))
    for index in range(row_num*col_num):
        plt.subplot(row_num,col_num,index+1)
        plt.imshow(images_list[index])
        plt.title(title_list[index])
        #隐藏x轴和y轴的标签刻度
        plt.xticks([])
        plt.yticks([])
    
    plt.show()


show_images()

数据加载器

基于paddlepaddle提供的paddle.io.Dataset类,封装一个十二生肖的数据加载器,用于后面的模型训练和评估,将图片的预处理也封装在里面

import os
import paddle
from paddle.vision import transforms
from PIL import Image
import numpy as np

class ZodiacDatasets(paddle.io.Dataset):
    """
    加载十二生肖数据
    """
    def __init__(self,mode="train",data_root="data/signs",img_size=(224,224)):
        self.data_root = data_root
        #判断mode是否正确
        if mode not in ["train","valid","test"]:
            assert(" is illegal,mode need is one of train,valid,test")
        #获取数据集的目录
        self._data_dir_path = os.path.join(data_root,mode)
        #获取十二生肖的类别名称
        self._zodiac_names = sorted(os.listdir(self._data_dir_path))
        #用来保存图片的路径
        self._img_path_list = []
        for name in self._zodiac_names:
            img_dir_path = os.path.join(self._data_dir_path,name)
            img_name_list = os.listdir(img_dir_path)
            for img_name in img_name_list:
                img_path = os.path.join(img_dir_path,img_name)
                self._img_path_list.append(img_path)
        #定义图像的预处理函数
        if mode == "train":
            self._transform = transforms.Compose([
                transforms.RandomResizedCrop(img_size),   #缩放图片并随机裁剪图片为指定shape
                transforms.RandomHorizontalFlip(0.5),     #随机水平翻转图片的概率为0.5
                transforms.ToTensor(),                    #转换图片的格式由HWC ==> CHW
                transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])  #图片通道像素的标准化
            ])
        else:
            self._transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(img_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
            ])
    def __getitem__(self,index):
        """根据index获取图片数据
        """
        #获取图片的路径
        img_path = self._img_path_list[index]
        #获取图片的标签
        img_label = img_path.split("/")[-2]
        #将生肖的标签名称转换为数字标签
        label_index = self._zodiac_names.index(img_label)
        #读取图片
        img = Image.open(img_path)
        if img.mode != "RGB":
            img = img.convert("RGB")
        #图片的预处理
        img = self._transform(img)
        return img,np.array(label_index,dtype=np.int64)

    def __len__(self):
        """获取数据集的大小
        """
        return len(self._img_path_list)


#加载训练集
train_datasets = ZodiacDatasets(mode="train")
#统计训练集的大小
print(len(train_datasets))
for img,img_label in train_datasets:
    print(img.shape,img_label)
    break

以上是关于paddlepaddle实现十二生肖的分类之数据的预处理的主要内容,如果未能解决你的问题,请参考以下文章

paddlepaddle十二生肖分类之模型训练和预测

paddlepaddle十二生肖分类之模型训练和预测

paddlepaddle十二生肖分类之模型(ResNet)构建

paddlepaddle十二生肖分类之模型(ResNet)构建

paddlepaddle十二生肖分类之模型(ResNet)构建

一文教你用paddlepaddle实现猫狗分类