【Pytorch+torchvision】MNIST手写数字识别(代码附最详细注释)

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了【Pytorch+torchvision】MNIST手写数字识别(代码附最详细注释)相关的知识,希望对你有一定的参考价值。

参考技术A PyTorch是一个非常流行的深度学习框架。但是与其他框架不同的是,PyTorch具有动态执行图,意味着计算图是动态创建的。

如果在 matplotlib 使用上出错时,可加上

会使错误消失,具体我也不知道是为什么,但是百度得到的解决方案,好用就是了。

(1)训练曲线,可以看到测试的损失在一点一点变小

(2)这是在epoch=3时的结果,可以看到准确率已经达到97%,识别手写数字已经几乎没有问题

(3)这是在epoch=8时的结果,可以看到准确率已经达到98%,识别手写数字变得更准确

end~

torchvision.transforms.Compose()详解Pytorch入门手册

简介

torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision.transforms主要是用于常见的一些图形变换。以下是torchvision的构成:

1.torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
2.torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
3.torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
4.torchvision.utils: 其他的一些有用的方法。


本文的主题是其中的torchvision.transforms.Compose()类。这个类的主要作用是串联多个图片变换的操作。

from torchvision.transforms import transforms

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),                  # 将输入图片resize成统一尺寸
    transforms.RandomRotation(degrees=(-10, 10)),   # 随机旋转,-10到10度之间随机选
    transforms.RandomHorizontalFlip(p=0.5),         # 随机水平翻转 选择一个概率概率
    transforms.RandomVerticalFlip(p=0.5),           # 随机垂直翻转
    transforms.RandomPerspective(distortion_scale=0.6, p=1.0),    # 随机视角
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),  # 随机选择的高斯模糊模糊图像
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std = [0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

以上是关于【Pytorch+torchvision】MNIST手写数字识别(代码附最详细注释)的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch实践模型训练(Torchvision)

PyTorch源码解读之torchvision.models(转)

(机器学习深度学习常用库框架|Pytorch篇)第三节:Pytorch之torchvision详解

(机器学习深度学习常用库框架|Pytorch篇)第三节:Pytorch之torchvision详解

小白学习PyTorch教程十七 PyTorch 中 数据集torchvision和torchtext

小白学习PyTorch教程十七 PyTorch 中 数据集torchvision和torchtext