pytorch实现CIFAR10实战

Posted William_Tao(攻城狮)

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch实现CIFAR10实战相关的知识,希望对你有一定的参考价值。

pytorch实现CIFAR10实战

步骤

代码

训练代码

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from  module import *
import  torchvision
import  torch.nn
#8.引入tensoboard
writer = SummaryWriter('../shizhan')
#1.引入ciff10的测试 和训练集
train_data = torchvision.datasets.CIFAR10(root=r'F:\\研究生\\深度学习项目练习\\b站PyTorch深度学习\\torchvision\\data',train=True,transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10(root=r'F:\\研究生\\深度学习项目练习\\b站PyTorch深度学习\\torchvision\\data',train=False,transform=torchvision.transforms.ToTensor())

train_len = len(train_data)
test_len = len(test_data)
print("训练集的长度为:".format(train_len))
print("测试集的长度为:".format(test_len))
#2.加载数据模型
train_loader = DataLoader(train_data,batch_size=64)
test_loader = DataLoader(test_data,batch_size=64)
# 3.加载网络模型
net = Module()

#4.定义损失函数
loss_fn = nn.CrossEntropyLoss()
#5.定义优化器
learing_rate = 1e-2
optimzer = torch.optim.SGD(net.parameters(),lr=learing_rate)

#6.设置训练模型所需的参数
#记录训练的次数
total_train_step=0
#记录测试的次数
total_test_step=0
#训练的轮数
epoch=10

#7.训练
net.train()
for i in range(epoch):
    print("------第轮数训练开始------".format(i+1))
    for data in train_loader:
        img,target = data
        output = net(img)
        loss=loss_fn(output,target)
        #对模型进行优化
        optimzer.zero_grad()
        loss.backward()
        optimzer.step()

        total_train_step=total_train_step+1
        if total_train_step %  100 ==0:
            print("训练次数,Loss:".format(total_train_step,loss))
    #测试步骤开始
    net.eval()
    total_test_loss=0
    total_test_accuracy=0
    with torch.no_grad():
        for data in test_loader:
            img,target = data
            output =net(img)
            loss = loss_fn(output,target)
            total_test_loss+=loss
            accuracy = (output.argmax(1)==target).sum()  #1代表横向
            total_test_accuracy+=accuracy.item()
        print("在整个测试集上的损失率为:".format(total_test_loss))
        print("在整个测试集上的正确率为:".format(total_test_accuracy/test_len))
        writer.add_scalar("test_loss",total_test_loss,total_test_step)
        writer.add_scalar("test_accuracy", total_test_accuracy, total_test_step)
        total_test_step+=1
        #保存每一轮的模型
        torch.save(net,r"F:\\研究生\\深度学习项目练习\\b站PyTorch深度学习\\net\\shizhanModelFile\\net_.pth".format(i))
        print("模型已保存")
    writer.close()


模型代码

import torch.nn
from torch import nn
from torch.nn import Conv2d,Sequential,MaxPool2d,Flatten,Linear
class Module(nn.Module):
    def __init__(self):
        super(Module, self).__init__()
        self.model=Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )
    def forward(self,x):
        x=self.model(x)
        return x

测试结果


以上是关于pytorch实现CIFAR10实战的主要内容,如果未能解决你的问题,请参考以下文章

pytorch实战学习第六篇:CIFAR-10分类实现

Pytorch CIFAR10图像分类 ResNeXt篇

Pytorch CIFAR10图像分类 ResNeXt篇

Pytorch CIFAR10图像分类 ResNet篇

Pytorch CIFAR10图像分类 DenseNet篇

Pytorch CIFAR10图像分类 GoogLeNet篇