PytorchCIFAR-10分类任务
Posted shuimuqingyang
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PytorchCIFAR-10分类任务相关的知识,希望对你有一定的参考价值。
https://blog.csdn.net/weixin_39837402/article/details/81054106
【Pytorch】CIFAR-10分类任务
CIFAR-10数据集共有60000张32*32彩色图片,分为10类,每类有6000张图片。其中50000张用于训练,构成5个训练batch,每一批次10000张图片,其余10000张图片用于测试。
CIFAR-10数据集下载地址:点击下载
数据读取,这里选择下载python版本的数据集,解压后得到如下文件:
其中data_batch_1~data_batch_5为训练集的5个批次,test_batch为测试集。
这些文件是python的序列化模型,这里使用python3,可以使用pickle模块读取这些数据:
-
def unpickle(file):
-
import pickle
-
with open(file, ‘rb‘) as fo:
-
dict = pickle.load(fo, encoding=‘bytes‘)
-
return dict
每一个batch文件包括一个字典,字典的元素是:
data:一个尺寸为10000*3072,数据格式为uint8的numpy array,每
一行数据存储了一张32*32彩色图片的数据,前1024位是图像的红色
通道数据,接着是绿色通道和蓝色通道。
label:一个包含10000个0-9数字的列表,对应data里每张图片的标签。
此外,数据集中还有一个batches.meta文件,它保存了一个python字典,
该字典对标签的10个数字0-9所代表的意义做了解释,比如0代表airplane,
1代表automobile。
这次使用Pytorch框架来进行实验,总体流程是,建立网络(这次小demo用Lenet),自定义数据集读取框架,虽然pytorch已经有关于cifar10的Dataset实例,但还是自己实现了一遍,接着用DataLoader分批读取数据集,定义损失函数和优化器,进行批次训练。
-
import torch
-
import torchvision
-
from torch.autograd import Variable
-
import torch.nn as nn
-
import torch.nn.functional as F
-
import torch.optim as optim
-
import torch.utils.data as Data
-
import torchvision.transforms as transforms
-
import numpy as np
-
from PIL import Image
-
import matplotlib.pyplot as plt
-
-
#预设参数
-
CLASS_NUM = 10
-
BATCH_SIZE = 128
-
EPOCH = 30
-
-
#Lenet网络代码
-
class Lenet(nn.Module):
-
def __init__(self):
-
super(Lenet,self).__init__()
-
#定义网络层
-
#入通道数,出通道数,卷积尺寸
-
self.conv1 = nn.Conv2d(3,6,5)
-
self.conv2 = nn.Conv2d(6,16,5)
-
self.fc1 = nn.Linear(16*5*5,120)
-
self.fc2 = nn.Linear(120,84)
-
self.fc3 = nn.Linear(84,10)
-
-
#将二维数据展开成一维数据以输入到全连接层
-
def num_flat_features(self,x):
-
#size为[batch_size,num_channels,height,width]
-
#除去batch_size,num_channels*height*width就是展开后维度
-
size = x.size()[1:]
-
num_features = 1
-
for s in size:
-
num_features = num_features*s
-
return num_features
-
-
def forward(self,x):
-
#定义前向传播
-
#输入 和 窗口尺寸
-
x = F.max_pool2d(F.relu(self.conv1(x)), 2)
-
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
-
x = x.view(-1, self.num_flat_features(x))
-
x = F.relu(self.fc1(x))
-
x = F.relu(self.fc2(x))
-
x = self.fc3(x)
-
return x
-
-
def unpickle(file):
-
import pickle
-
with open(file, ‘rb‘) as fo:
-
dict = pickle.load(fo, encoding=‘bytes‘)
-
return dict
-
-
#从源文件读取数据
-
#返回 train_data[50000,3072]和labels[50000]
-
# test_data[10000,3072]和labels[10000]
-
def get_data(train=False):
-
data = None
-
labels = None
-
if train == True:
-
for i in range(1,6):
-
batch = unpickle(‘data/cifar-10-batches-py/data_batch_‘+str(i))
-
if i == 1:
-
data = batch[b‘data‘]
-
else:
-
data = np.concatenate([data,batch[b‘data‘]])
-
-
if i == 1:
-
labels = batch[b‘labels‘]
-
else:
-
labels = np.concatenate([labels,batch[b‘labels‘]])
-
else:
-
batch = unpickle(‘data/cifar-10-batches-py/test_batch‘)
-
data = batch[b‘data‘]
-
labels = batch[b‘labels‘]
-
return data,labels
-
-
#图像预处理函数,Compose会将多个transform操作包在一起
-
#对于彩色图像,色彩通道不存在平稳特性
-
transform = transforms.Compose([
-
# ToTensor是指把PIL.Image(RGB) 或者numpy.ndarray(H x W x C)
-
# 从0到255的值映射到0到1的范围内,并转化成Tensor格式。
-
transforms.ToTensor(),
-
#Normalize函数将图像数据归一化到[-1,1]
-
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
-
]
-
)
-
-
#将标签转换为torch.LongTensor
-
def target_transform(label):
-
label = np.array(label)
-
target = torch.from_numpy(label).long()
-
return target
-
-
‘‘‘
-
自定义数据集读取框架来载入cifar10数据集
-
需要继承data.Dataset
-
‘‘‘
-
class Cifar10_Dataset(Data.Dataset):
-
def __init__(self,train=True,transform=None,target_transform=None):
-
#初始化文件路径
-
self.transform = transform
-
self.target_transform = target_transform
-
self.train = train
-
#载入训练数据集
-
if self.train:
-
self.train_data,self.train_labels = get_data(train)
-
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
-
# 将图像数据格式转换为[height,width,channels]方便预处理
-
self.train_data = self.train_data.transpose((0, 2, 3, 1))
-
#载入测试数据集
-
else:
-
self.test_data,self.test_labels = get_data()
-
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
-
self.test_data = self.test_data.transpose((0, 2, 3, 1))
-
pass
-
def __getitem__(self, index):
-
#从数据集中读取一个数据并对数据进行
-
#预处理返回一个数据对,如(data,label)
-
if self.train:
-
img, label = self.train_data[index], self.train_labels[index]
-
else:
-
img, label = self.test_data[index], self.test_labels[index]
-
-
img = Image.fromarray(img)
-
#图像预处理
-
if self.transform is not None:
-
img = self.transform(img)
-
#标签预处理
-
if self.target_transform is not None:
-
target = self.target_transform(label)
-
-
return img, target
-
def __len__(self):
-
#返回数据集的size
-
if self.train:
-
return len(self.train_data)
-
else:
-
return len(self.test_data)
-
-
if __name__ == ‘__main__‘:
-
#读取训练集和测试集
-
train_data = Cifar10_Dataset(True,transform,target_transform)
-
print(‘size of train_data:{}‘.format(train_data.__len__()))
-
test_data = Cifar10_Dataset(False,transform,target_transform)
-
print(‘size of test_data:{}‘.format(test_data.__len__()))
-
train_loader = Data.DataLoader(dataset=train_data, batch_size = BATCH_SIZE, shuffle=True)
-
-
net = Lenet()
-
optimizer = optim.Adam(net.parameters(), lr = 0.001, betas=(0.9, 0.99))
-
#在使用CrossEntropyLoss时target直接使用类别索引,不适用one-hot
-
loss_fn = nn.CrossEntropyLoss()
-
-
loss_list = []
-
for epoch in range(1,EPOCH+1):
-
#训练部分
-
for step,(x,y) in enumerate(train_loader):
-
b_x = Variable(x)
-
b_y = Variable(y)
-
output = net(b_x)
-
loss = loss_fn(output,b_y)
-
optimizer.zero_grad()
-
loss.backward()
-
optimizer.step()
-
#记录loss
-
if step%50 == 0:
-
loss_list.append(loss)
-
#每完成一个epoch进行一次测试观察效果
-
pre_correct = 0.0
-
test_loader = Data.DataLoader(dataset=test_data, batch_size = 100, shuffle=True)
-
for (x,y) in (test_loader):
-
b_x = Variable(x)
-
b_y = Variable(y)
-
output = net(b_x)
-
pre = torch.max(output,1)[1]
-
pre_correct = pre_correct+float(torch.sum(pre==b_y))
-
-
print(‘EPOCH:{epoch},ACC:{acc}%‘.format(epoch=epoch,acc=(pre_correct/float(10000))*100))
-
-
#保存网络模型
-
torch.save(net,‘lenet_cifar_10.model‘)
-
#绘制loss变化曲线
-
plt.plot(loss_list)
-
plt.show()
第一个pytorch demo跑通了,但是训练模型效果很不好,应该是Lenet作用于Cifar10有些过于力不从心了,刚开始接触深度学习的图像领域还不怎么懂,下次换一个更强大的网络。
以上是关于PytorchCIFAR-10分类任务的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch CIFAR10图像分类 EfficientNet v1篇
Pytorch CIFAR10图像分类 EfficientNet v1篇