torch三层全连接实现手写数字识别
Posted 一只特立独行的猫
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch三层全连接实现手写数字识别相关的知识,希望对你有一定的参考价值。
利用全连接层将数据维度缩小到【0,9】,用ReLU激活,使用MNIST数据集(28*28,黑色为底色,白色的字的颜色),比较low,不过用来入门不错。
from matplotlib import pyplot as plt
import matplotlib
import torch
import numpy as np
from torchvision import datasets ,transforms
from torch import nn
from torch import optim
#定义转换器,包含操作:将图像转为tensor和(0,1)之间,将数据转换为(-1,1)
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,),(0.5))])
#下载数据集
dataset = datasets.MNIST('MNIST_data',download=False,transform=transform)
#加载数据集,定义训练批次为1
trainloader = torch.utils.data.DataLoader(dataset=dataset,batch_size=1)
#定义神经网络模型
model = nn.Sequential(nn.Linear(784,256),
nn.ReLU(),
nn.Linear(256,64),
nn.ReLU(),
nn.Linear(64,10),
nn.LogSoftmax(dim=1))
criterion=nn.NLLLoss()
#随机梯度下降
optimizer=optim.SGD(model.parameters(),lr=0.005)
#设置迭代次数为8次
epoch=8
for e in range(epoch):
running_loss=0
#训练一个批次,image是图像,labels是标签
for images ,labels in trainloader:
images=images.view(images.shape[0],-1)
optimizer.zero_grad()
output=model.forward(images)
loss=criterion(output,labels)
loss.backward()
optimizer.step()
running_loss+=loss.item()
else:
print(f"第{e}代,训练损失:{running_loss/len(trainloader)}")
#验证集,利用matplotlib画图
images,labels=next(iter(trainloader))
img=images[0].view(1,28*28)
with torch.no_grad():
logits=model.forward(img)
result=torch.softmax(logits,dim=1)
result=result.data.numpy().squeeze()
fig,(ax1,ax2) = plt.subplots(figsize=(6,9),ncols=2)
ax1.imshow(img.resize_(1,28,28).numpy().squeeze())
ax1.axis("off")
ax2.barh(np.arange(10),result)
ax2.set_aspect(0.1)
ax2.set_yticks(np.arange(10))
ax2.set_yticklabels(np.arange(10))
ax2.set_xlim(0,1.1)
plt.tight_layout()
plt.show()
#保存模型
torch.save(model,"weight.pth")
以上是关于torch三层全连接实现手写数字识别的主要内容,如果未能解决你的问题,请参考以下文章
深度学习--TensorFlow(项目)识别自己的手写数字(基于CNN卷积神经网络)