[基于Pytorch的MNIST识别05]总结

Posted AIplusX

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[基于Pytorch的MNIST识别05]总结相关的知识,希望对你有一定的参考价值。

写在前面

走走停停,我自己做的第一个机器学习的项目:MNIST手写字符集的识别,终于结束了,一路走来也算是踩了大大小小的坑,在这篇文章里做一个总结。

模型配置

模型总共有4层,输入层有784个神经元,分别对应28X28个像素的MNIST手写字符集图像(预先进行归一化),隐藏层有2个,每层16个神经元,输出层有10个神经元,分别对应数字0~9。

模型结构如图所示:

pytorch环境版本:1.8.1+cu102

模型最终在验证集上的准确率是95%:

知识点

1:pytorch模型保存:

save_path = './model/mnist_net'
torch.save(mnist_net, save_path) 

2:模型训练技巧

在模型初步训练的时候使用一个较大的batch(例如128)进行初步的训练,大概100个epoch之后验证集上的正确率大概可以达到90%,之后如果还用大batch的话loss下降的就很慢,正确率上不去,所以之后我分别使用小batch(例如32,1)进行训练,最终达到了95%的正确率。

源码

相关的资料关注我的公众号后即可下载,文件结构如图所示:

主要内容

主要内容在我的古月居博客:
[基于Pytorch的MNIST识别05]总结

之后我就会做一些轨迹规划算法的学习以及实现了,欢迎大家持续关注。

以上是关于[基于Pytorch的MNIST识别05]总结的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch基于 LSTM 的手写数字识别(MNIST)

基于PyTorch实现MNIST手写字识别

PyTorch基于CNN的手写数字识别(在MNIST数据集上训练)

图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)

图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(单向LSTM,附完整代码和数据集)

[基于Pytorch的MNIST识别01]神经网络建立