【Pytorch+torchvision】MNIST手写数字识别(代码附最详细注释)
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了【Pytorch+torchvision】MNIST手写数字识别(代码附最详细注释)相关的知识,希望对你有一定的参考价值。
参考技术A PyTorch是一个非常流行的深度学习框架。但是与其他框架不同的是,PyTorch具有动态执行图,意味着计算图是动态创建的。如果在 matplotlib 使用上出错时,可加上
会使错误消失,具体我也不知道是为什么,但是百度得到的解决方案,好用就是了。
(1)训练曲线,可以看到测试的损失在一点一点变小
(2)这是在epoch=3时的结果,可以看到准确率已经达到97%,识别手写数字已经几乎没有问题
(3)这是在epoch=8时的结果,可以看到准确率已经达到98%,识别手写数字变得更准确
end~
torchvision.transforms.Compose()详解Pytorch入门手册
简介
torchvision
是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision.transforms
主要是用于常见的一些图形变换。以下是torchvision
的构成:
1.torchvision.datasets
: 一些加载数据的函数及常用的数据集接口;
2.torchvision.models
: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
3.torchvision.transforms
: 常用的图片变换,例如裁剪、旋转等;
4.torchvision.utils
: 其他的一些有用的方法。
本文的主题是其中的torchvision.transforms.Compose()
类。这个类的主要作用是串联多个图片变换的操作。
from torchvision.transforms import transforms
train_transforms = transforms.Compose([
transforms.Resize([224, 224]), # 将输入图片resize成统一尺寸
transforms.RandomRotation(degrees=(-10, 10)), # 随机旋转,-10到10度之间随机选
transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转 选择一个概率概率
transforms.RandomVerticalFlip(p=0.5), # 随机垂直翻转
transforms.RandomPerspective(distortion_scale=0.6, p=1.0), # 随机视角
transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)), # 随机选择的高斯模糊模糊图像
transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
transforms.Normalize( # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
mean=[0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225]) # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])
以上是关于【Pytorch+torchvision】MNIST手写数字识别(代码附最详细注释)的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch源码解读之torchvision.models(转)
(机器学习深度学习常用库框架|Pytorch篇)第三节:Pytorch之torchvision详解
(机器学习深度学习常用库框架|Pytorch篇)第三节:Pytorch之torchvision详解