PyTorch基于CNN的手写数字识别(在MNIST数据集上训练)

Posted Xavier Jiezou

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch基于CNN的手写数字识别(在MNIST数据集上训练)相关的知识,希望对你有一定的参考价值。

最终成果

http://pytorch-cnn-mnist.herokuapp.com/

GITHUB

https://github.com/XavierJiezou/pytorch-cnn-mnist

本文以最经典的mnist数据集为例,讲述了使用pytorch做机器学习的一整套流程,文中所提到的所有代码都可以到github中查看。

项目场景

简单的学习pytorch、自动求导和神经网络的知识后,我们来练习使用mnist数据集训练一个cnn手写数字识别模型。

导入模块

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

pytorch安装教程可以看这这两篇文章:WindowsLinux

matplotlib库用于绘图,如果没有请通过pip命令安装:

pip install matplotlib

训练设备

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

如果有可用的gpu就在gpu中训练,没有就使用cpu训练。

定义超参数

EPOCH = 10
BATCH_SIZE = 128
LR = 1E-3

EPOCH:训练的轮数。

BATCH_SIZE:数据加载器的批次大小。

LR:优化器的学习率。

下载数据集

train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.ToTensor(),
    download=True
)
test_file = datasets.MNIST(
    root='./dataset/',
    train=False,
    transform=transforms.ToTensor()
)

root:指定数据集的存放路径。

trainTrue表示训练集;False表示测试集。

transform:采用何种方式进行图像转换;transforms.ToTensor()是将形状为(H x W x C)的图片数据转换为形状为(C x H x W)的张量,C为图片通道数,这里是1,也就是灰度图片,HW分别为图片宽高,都是28。然后再把图片的灰度值[0, 255]归一化到[0, 1],也就是除以255

downloadTrue:下载数据集;False:不下载;默认为False,第一次运行代码时设置为True,之后设置为False就行了。

数据可视化

训练数据可视化

train_data = train_file.data
train_targets = train_file.targets
print(train_data.size())  # [60000, 28, 28]
print(train_targets.size())  # [60000]
# visualization
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(train_targets[i].numpy())
    plt.axis('off')
    plt.imshow(train_data[i], cmap='gray')
plt.show()

打印训练数据及标签可知:训练数据是6万张28x28的灰度图片,以及6万个标签。

测试数据可视化

test_file = datasets.MNIST(
    root='./dataset/',
    train=False,
    transform=transforms.ToTensor()
)
test_data = test_file.data
test_targets = test_file.targets
print(test_data.size())  # [10000, 28, 28]
print(test_targets.size())  # [10000]
# visualization
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(test_targets[i].numpy())
    plt.axis('off')
    plt.imshow(test_data[i], cmap='gray')
plt.show()

打印测试数据及标签可知:测试数据是1万张28x28的灰度图片,以及1万个标签。

数据加载器

train_loader = DataLoader(
    dataset=train_file,
    batch_size=BATCH_SIZE,
    shuffle=True
)
test_loader = DataLoader(
    dataset=test_file,
    batch_size=BATCH_SIZE,
    shuffle=False
)

dataset:指定数据集。
batch_size:批次的大小。
shuffle: True:打乱数据顺序;False:不打乱数据顺序。一般训练的时候打乱顺序会取得更好的效果,测试的时候不需要打乱顺序。

模型结构

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Sequential(
            # [BATCH_SIZE, 1, 28, 28]
            nn.Conv2d(1, 32, 5, 1, 2),
            # [BATCH_SIZE, 32, 28, 28]
            nn.ReLU(),
            nn.MaxPool2d(2),
            # [BATCH_SIZE, 32, 14, 14]
            nn.Conv2d(32, 64, 5, 1, 2),
            # [BATCH_SIZE, 64, 14, 14]
            nn.ReLU(),
            nn.MaxPool2d(2)
            # [BATCH_SIZE, 64, 7, 7]
        )
        self.fc = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        y = self.fc(x)
        return y

创建模型

model = CNN().to(device)
optim = torch.optim.Adam(model.parameters(), LR)
lossf = nn.CrossEntropyLoss()

训练模型

for epoch in range(EPOCH):
    for step, (data, targets) in enumerate(train_loader):
        optim.zero_grad()
        data = data.to(device)
        targets = targets.to(device)
        output = model(data)
        loss = lossf(output, targets)
        loss.backward()
        optim.step()
  1. 梯度清零
  2. 将图片和标签加载到GPUCPU
  3. 将图片传入模型中训练
  4. 根据损失函数计算输出结果和标签的误差
  5. 误差的反向传播
  6. 更新参数

计算损失

loss = 0
total = 0
correct = 0
with torch.no_grad():
    for data, targets in loader:
        data = data.to(device)
        targets = targets.to(device)
        output = model(data)
        loss += lossf(output, targets)
        correct += (output.argmax(1) == targets).sum()
        total += data.size(0)
loss = loss.item()/len(test_loader)
acc = correct.item()/total

计算模型对于整个训练集或测试集的损失准确率

训练过程打印

EPOCH: 01/10 STEP: 469/469 LOSS: 0.0645 ACC: 0.9807 VAL-LOSS: 0.0585 VAL-ACC: 0.9799 TOTAL-TIME: 48
EPOCH: 02/10 STEP: 469/469 LOSS: 0.0405 ACC: 0.9879 VAL-LOSS: 0.0391 VAL-ACC: 0.9870 TOTAL-TIME: 52
EPOCH: 03/10 STEP: 469/469 LOSS: 0.0297 ACC: 0.9913 VAL-LOSS: 0.0320 VAL-ACC: 0.9884 TOTAL-TIME: 46
EPOCH: 04/10 STEP: 469/469 LOSS: 0.0202 ACC: 0.9939 VAL-LOSS: 0.0271 VAL-ACC: 0.9900 TOTAL-TIME: 53
EPOCH: 05/10 STEP: 469/469 LOSS: 0.0192 ACC: 0.9941 VAL-LOSS: 0.0278 VAL-ACC: 0.9900 TOTAL-TIME: 46
EPOCH: 06/10 STEP: 469/469 LOSS: 0.0150 ACC: 0.9956 VAL-LOSS: 0.0294 VAL-ACC: 0.9897 TOTAL-TIME: 47
EPOCH: 07/10 STEP: 469/469 LOSS: 0.0114 ACC: 0.9966 VAL-LOSS: 0.0245 VAL-ACC: 0.9923 TOTAL-TIME: 54
EPOCH: 08/10 STEP: 469/469 LOSS: 0.0115 ACC: 0.9967 VAL-LOSS: 0.0269 VAL-ACC: 0.9906 TOTAL-TIME: 45
EPOCH: 09/10 STEP: 469/469 LOSS: 0.0094 ACC: 0.9972 VAL-LOSS: 0.0278 VAL-ACC: 0.9909 TOTAL-TIME: 47
EPOCH: 10/10 STEP: 469/469 LOSS: 0.0077 ACC: 0.9977 VAL-LOSS: 0.0278 VAL-ACC: 0.9916 TOTAL-TIME: 35

EPOCH:训练的轮数

STEP:数据加载器迭代的次数

LOSS:整个训练集的平均损失

ACC:整个训练集的准确率

VAL-LOSS:整个测试集的平均损失

VAL-ACC:整个测试集的准确率

TOTAL-TIME:训练每轮消耗的时间(单位:秒)。

以上是在英伟达GTX 1050 Ti训练的结果,大概每轮40秒;如果用GTX 1080 Ti每轮大概是20秒。

保存最佳模型

temp = 0
if val_acc > temp:
	torch.save(model.state_dict(), 'model.pt')
	temp = val_acc

根据EPOCH设置的值不同,模型会训练很多轮,为了以后方便加载,我们将最佳的模型保存下来。

首先,我们要明白:什么样的模型是最佳的?当然,对于一个好的模型来说,相比训练集,我们更看重的是它在测试集上的效果。

测试集上有两个指标可以衡量模型的好坏,损失准确率。我们这里就保存准确率最高的。(当然损失最小的也行)

| BEST-MODEL | EPOCH: 07/10 STEP: 469/469 LOSS: 0.0114 ACC: 0.9966 VAL-LOSS: 0.0245 VAL-ACC: 0.9923

就上述10轮的训练结果来看,准确率最高的是第7轮,所以我们需要把第7轮训练的模型保存下来。

加载模型

model = CNN()
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
model.eval()

更详细的pytorch保存和加载模型的方法可以看我翻译的这一篇官方文档。

测试模型

打开Windows系统自带的画图应用,创建28x28的画布,背景颜色设置为黑色,前景色为白色,然后随便写几个数字,保存为png或其它格式的图片。

这里我手写了10个黑底白字的数字:

然后加载最佳模型,预测结果如下:(单次预测准确率100%

pred: tensor([0, 1, 2, 3, 4, 5, 6, 7, 2, 9])
true: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

我也手写了10个白底黑字:

测试效果很不理想,毕竟模型是在黑底白字的数据集上训练的嘛:)(单次预测准确率30%)

pred: tensor([3, 2, 2, 3, 2, 5, 5, 2, 2, 2])
true: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

部署模型

我们不可能每次都打开画图软件写一个数字来测试吧?

文章开头的最终成果就是在heroku上成功部署的实例。

pytorch模型线上部署的教程正在制作中,敬请期待……

引用参考

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

以上是关于PyTorch基于CNN的手写数字识别(在MNIST数据集上训练)的主要内容,如果未能解决你的问题,请参考以下文章

AI常用框架和工具丨13. PyTorch实现基于CNN的手写数字识别

AI常用框架和工具丨13. PyTorch实现基于CNN的手写数字识别

用pytorch做手写数字识别,识别l率达97.8%

PyTorch实现用CNN识别手写数字

全网最详细最好懂 PyTorch CNN案例分析 识别手写数字

利用pytorch CNN手写字母识别神经网络模型识别多手写字母(A-Z)