torch三层全连接实现手写数字识别

Posted 一只特立独行的猫

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch三层全连接实现手写数字识别相关的知识,希望对你有一定的参考价值。

利用全连接层将数据维度缩小到【0,9】,用ReLU激活,使用MNIST数据集(28*28,黑色为底色,白色的字的颜色),比较low,不过用来入门不错。

from matplotlib import pyplot as plt
import matplotlib
import torch
import numpy as np

from torchvision import datasets ,transforms
from torch import nn
from torch import optim

#定义转换器,包含操作:将图像转为tensor和(0,1)之间,将数据转换为(-1,1)
transform = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.5,),(0.5))])
#下载数据集
dataset = datasets.MNIST('MNIST_data',download=False,transform=transform)

#加载数据集,定义训练批次为1
trainloader = torch.utils.data.DataLoader(dataset=dataset,batch_size=1)
#定义神经网络模型
model = nn.Sequential(nn.Linear(784,256),
                    nn.ReLU(),
                    nn.Linear(256,64),
                    nn.ReLU(),
                    nn.Linear(64,10),
                    nn.LogSoftmax(dim=1))

criterion=nn.NLLLoss()
#随机梯度下降
optimizer=optim.SGD(model.parameters(),lr=0.005)

#设置迭代次数为8次
epoch=8
for e in range(epoch):
    running_loss=0
    #训练一个批次,image是图像,labels是标签
    for images ,labels in trainloader:
        images=images.view(images.shape[0],-1)
        optimizer.zero_grad()
        output=model.forward(images)
        loss=criterion(output,labels)
        loss.backward()
        optimizer.step()

        running_loss+=loss.item()
    else:
        print(f"第{e}代,训练损失:{running_loss/len(trainloader)}")

#验证集,利用matplotlib画图
images,labels=next(iter(trainloader))
img=images[0].view(1,28*28)

with torch.no_grad():
    logits=model.forward(img)
result=torch.softmax(logits,dim=1)
result=result.data.numpy().squeeze()
fig,(ax1,ax2) = plt.subplots(figsize=(6,9),ncols=2)
ax1.imshow(img.resize_(1,28,28).numpy().squeeze())
ax1.axis("off")
ax2.barh(np.arange(10),result)
ax2.set_aspect(0.1)
ax2.set_yticks(np.arange(10))
ax2.set_yticklabels(np.arange(10))
ax2.set_xlim(0,1.1)

plt.tight_layout()
plt.show()

#保存模型
torch.save(model,"weight.pth")

以上是关于torch三层全连接实现手写数字识别的主要内容,如果未能解决你的问题,请参考以下文章

三层CNN实现手写数字识别(新手项目)

全连接神经网络实现识别手写数据集MNIST

深度学习--TensorFlow(项目)识别自己的手写数字(基于CNN卷积神经网络)

[Pytorch系列-29]:神经网络基础 - 全连接浅层神经网络实现10分类手写数字识别

torch教程[1]用numpy实现三层全连接神经网络

用pytorch实现多层感知机(MLP)(全连接神经网络FC)分类MNIST手写数字体的识别