[基于Pytorch的MNIST识别05]总结
Posted AIplusX
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[基于Pytorch的MNIST识别05]总结相关的知识,希望对你有一定的参考价值。
写在前面
走走停停,我自己做的第一个机器学习的项目:MNIST手写字符集的识别,终于结束了,一路走来也算是踩了大大小小的坑,在这篇文章里做一个总结。
模型配置
模型总共有4层,输入层有784个神经元,分别对应28X28个像素的MNIST手写字符集图像(预先进行归一化),隐藏层有2个,每层16个神经元,输出层有10个神经元,分别对应数字0~9。
模型结构如图所示:
pytorch环境版本:1.8.1+cu102
模型最终在验证集上的准确率是95%:
知识点
1:pytorch模型保存:
save_path = './model/mnist_net'
torch.save(mnist_net, save_path)
2:模型训练技巧
在模型初步训练的时候使用一个较大的batch(例如128)进行初步的训练,大概100个epoch之后验证集上的正确率大概可以达到90%,之后如果还用大batch的话loss下降的就很慢,正确率上不去,所以之后我分别使用小batch(例如32,1)进行训练,最终达到了95%的正确率。
源码
相关的资料关注我的公众号后即可下载,文件结构如图所示:
主要内容
主要内容在我的古月居博客:
[基于Pytorch的MNIST识别05]总结
之后我就会做一些轨迹规划算法的学习以及实现了,欢迎大家持续关注。
以上是关于[基于Pytorch的MNIST识别05]总结的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch基于CNN的手写数字识别(在MNIST数据集上训练)
图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)