pytorch实现CIFAR10实战
Posted William_Tao(攻城狮)
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch实现CIFAR10实战相关的知识,希望对你有一定的参考价值。
pytorch实现CIFAR10实战
步骤
代码
训练代码
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from module import *
import torchvision
import torch.nn
#8.引入tensoboard
writer = SummaryWriter('../shizhan')
#1.引入ciff10的测试 和训练集
train_data = torchvision.datasets.CIFAR10(root=r'F:\\研究生\\深度学习项目练习\\b站PyTorch深度学习\\torchvision\\data',train=True,transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10(root=r'F:\\研究生\\深度学习项目练习\\b站PyTorch深度学习\\torchvision\\data',train=False,transform=torchvision.transforms.ToTensor())
train_len = len(train_data)
test_len = len(test_data)
print("训练集的长度为:".format(train_len))
print("测试集的长度为:".format(test_len))
#2.加载数据模型
train_loader = DataLoader(train_data,batch_size=64)
test_loader = DataLoader(test_data,batch_size=64)
# 3.加载网络模型
net = Module()
#4.定义损失函数
loss_fn = nn.CrossEntropyLoss()
#5.定义优化器
learing_rate = 1e-2
optimzer = torch.optim.SGD(net.parameters(),lr=learing_rate)
#6.设置训练模型所需的参数
#记录训练的次数
total_train_step=0
#记录测试的次数
total_test_step=0
#训练的轮数
epoch=10
#7.训练
net.train()
for i in range(epoch):
print("------第轮数训练开始------".format(i+1))
for data in train_loader:
img,target = data
output = net(img)
loss=loss_fn(output,target)
#对模型进行优化
optimzer.zero_grad()
loss.backward()
optimzer.step()
total_train_step=total_train_step+1
if total_train_step % 100 ==0:
print("训练次数,Loss:".format(total_train_step,loss))
#测试步骤开始
net.eval()
total_test_loss=0
total_test_accuracy=0
with torch.no_grad():
for data in test_loader:
img,target = data
output =net(img)
loss = loss_fn(output,target)
total_test_loss+=loss
accuracy = (output.argmax(1)==target).sum() #1代表横向
total_test_accuracy+=accuracy.item()
print("在整个测试集上的损失率为:".format(total_test_loss))
print("在整个测试集上的正确率为:".format(total_test_accuracy/test_len))
writer.add_scalar("test_loss",total_test_loss,total_test_step)
writer.add_scalar("test_accuracy", total_test_accuracy, total_test_step)
total_test_step+=1
#保存每一轮的模型
torch.save(net,r"F:\\研究生\\深度学习项目练习\\b站PyTorch深度学习\\net\\shizhanModelFile\\net_.pth".format(i))
print("模型已保存")
writer.close()
模型代码
import torch.nn
from torch import nn
from torch.nn import Conv2d,Sequential,MaxPool2d,Flatten,Linear
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
self.model=Sequential(
Conv2d(3,32,5,padding=2),
MaxPool2d(2),
Conv2d(32,32,5,padding=2),
MaxPool2d(2),
Conv2d(32,64,5,padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024,64),
Linear(64,10)
)
def forward(self,x):
x=self.model(x)
return x
测试结果
以上是关于pytorch实现CIFAR10实战的主要内容,如果未能解决你的问题,请参考以下文章