paddlepaddle实现十二生肖的分类之数据的预处理
Posted 修炼之路
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了paddlepaddle实现十二生肖的分类之数据的预处理相关的知识,希望对你有一定的参考价值。
数据集说明
数据集一共包含3个目录train
、valid
和test
,每个目录都包含了12生肖(类别)的图片,通过下面的链接可以直接下载数据集
数据分析
统计数据集中每个类别的数据分布情况
import os
def print_classes_info(mode="train",data_dir = "data/signs"):
datasets_dir = os.path.join(data_dir,mode)
classes_names = os.listdir(datasets_dir)
#用来保存每个类别的数量信息
classes_num_infos = dict()
for class_name in classes_names:
#获取类别的目录
class_dir_path = os.path.join(datasets_dir,class_name)
img_names = os.listdir(class_dir_path)
#记录每个类别的图片数量
classes_num_infos[class_name] = len(img_names)
print(":".format(mode,classes_num_infos))
#打印数据的分布情况
print_classes_info("train")
print_classes_info("valid")
print_classes_info("test")
train:‘goat’: 600, ‘tiger’: 600, ‘horse’: 600, ‘snake’: 600, ‘pig’: 600, ‘dragon’: 600, ‘ox’: 600, ‘monkey’: 600, ‘rabbit’: 600, ‘rooster’: 600, ‘dog’: 600, ‘ratt’: 600
valid:‘goat’: 55, ‘tiger’: 55, ‘horse’: 55, ‘snake’: 55, ‘pig’: 55, ‘dragon’: 55, ‘ox’: 55, ‘monkey’: 55, ‘rabbit’: 55, ‘rooster’: 55, ‘dog’: 55, ‘ratt’: 55
test:‘goat’: 55, ‘tiger’: 55, ‘horse’: 55, ‘snake’: 55, ‘pig’: 55, ‘dragon’: 55, ‘ox’: 55, ‘monkey’: 55, ‘rabbit’: 55, ‘rooster’: 55, ‘dog’: 55, ‘ratt’: 55
在训练集中每个类别包含600张
图片,验证集中每个类别包含55张
图片,测试集中每个类别包含55
张图片,因为这里的数据都比较平衡,后面我们就不需要去考虑数据的平衡问题了。
展示图片
从数据集中选择一部分数据进行查看,了解数据的分布特征
import os
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
def show_images(mode="train",data_dir = "data/signs",row_num=3,col_num=4):
datasets_dir = os.path.join(data_dir,mode)
#获取所有的类别数据
classes_names = os.listdir(datasets_dir)
#用来保存图片数据
images_list = []
title_list = []
for cls_name in classes_names:
cls_dir_path = os.path.join(datasets_dir,cls_name)
img_name_list = os.listdir(cls_dir_path)
for img_name in img_name_list:
img_path = os.path.join(cls_dir_path,img_name)
image = Image.open(img_path)
#保存图片和标签数据
images_list.append(np.array(image))
title_list.append(cls_name)
break
#设置图片的大小
plt.figure(figsize=(8,8))
for index in range(row_num*col_num):
plt.subplot(row_num,col_num,index+1)
plt.imshow(images_list[index])
plt.title(title_list[index])
#隐藏x轴和y轴的标签刻度
plt.xticks([])
plt.yticks([])
plt.show()
show_images()
数据加载器
基于paddlepaddle提供的paddle.io.Dataset
类,封装一个十二生肖的数据加载器,用于后面的模型训练和评估,将图片的预处理也封装在里面
import os
import paddle
from paddle.vision import transforms
from PIL import Image
import numpy as np
class ZodiacDatasets(paddle.io.Dataset):
"""
加载十二生肖数据
"""
def __init__(self,mode="train",data_root="data/signs",img_size=(224,224)):
self.data_root = data_root
#判断mode是否正确
if mode not in ["train","valid","test"]:
assert(" is illegal,mode need is one of train,valid,test")
#获取数据集的目录
self._data_dir_path = os.path.join(data_root,mode)
#获取十二生肖的类别名称
self._zodiac_names = sorted(os.listdir(self._data_dir_path))
#用来保存图片的路径
self._img_path_list = []
for name in self._zodiac_names:
img_dir_path = os.path.join(self._data_dir_path,name)
img_name_list = os.listdir(img_dir_path)
for img_name in img_name_list:
img_path = os.path.join(img_dir_path,img_name)
self._img_path_list.append(img_path)
#定义图像的预处理函数
if mode == "train":
self._transform = transforms.Compose([
transforms.RandomResizedCrop(img_size), #缩放图片并随机裁剪图片为指定shape
transforms.RandomHorizontalFlip(0.5), #随机水平翻转图片的概率为0.5
transforms.ToTensor(), #转换图片的格式由HWC ==> CHW
transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225]) #图片通道像素的标准化
])
else:
self._transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(img_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
])
def __getitem__(self,index):
"""根据index获取图片数据
"""
#获取图片的路径
img_path = self._img_path_list[index]
#获取图片的标签
img_label = img_path.split("/")[-2]
#将生肖的标签名称转换为数字标签
label_index = self._zodiac_names.index(img_label)
#读取图片
img = Image.open(img_path)
if img.mode != "RGB":
img = img.convert("RGB")
#图片的预处理
img = self._transform(img)
return img,np.array(label_index,dtype=np.int64)
def __len__(self):
"""获取数据集的大小
"""
return len(self._img_path_list)
#加载训练集
train_datasets = ZodiacDatasets(mode="train")
#统计训练集的大小
print(len(train_datasets))
for img,img_label in train_datasets:
print(img.shape,img_label)
break
以上是关于paddlepaddle实现十二生肖的分类之数据的预处理的主要内容,如果未能解决你的问题,请参考以下文章
paddlepaddle十二生肖分类之模型(ResNet)构建
paddlepaddle十二生肖分类之模型(ResNet)构建