pytorch:实践MNIST手写数字识别

Posted Z|Star

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch:实践MNIST手写数字识别相关的知识,希望对你有一定的参考价值。

在本专栏第十篇记录过CNN的理论,并大致了解使用CNN+残差网络训练MNIST的方式,由于课件中不包含完整代码,因此想要复现一遍,但遇到各种各样的坑,纸上得来,终觉浅~

第一个问题:MNIST数据集的获取

train_dataset = datasets.MNIST(root='../dataset/mnist',
                               train=True,
                               download=True,
                               transform=transform)

在datasets.MNIST的中可以设置download=True,这样设置,系统会自动在root里面检测MNIST数据文件,如果存在则不下载,如果不存在则自动联网下载。我尝试自动联网下载,结果十几分钟之后,下载一半之后报错,网络出现问题。于是翻阅其它资源,将其手动下载下来添加到minst文件夹中自动创建的raw文件夹中。
(如果你也需要这个数据集,可以在微信公众号“我有一计”内回复“数据集”,即可获取下载链接)

第二个问题:batch_size的大小的选取
回顾一下之前就记录过的三个概念:epoch、 iteration和batchsize
1)batchsize:批大小。在深度学习中,一般采用SGD训练,即每次训练在训练集中取batchsize个样本训练;
2)iteration:1个iteration等于使用batchsize个样本训练一次;
3)epoch:1个epoch等于使用训练集中的全部样本训练一次;
GPU对2的幂次的batch可以发挥更佳的性能,因此设置成16、32、64、128时往往要比设置为整10、整100的倍数时表现更优。
在现存允许的情况下batch_size可以取相对大一些

第三个问题:维度匹配
深度学习最麻烦的就是维度匹配,按照课件手打的代码出现维度不匹配的警告,具体原因尚不明朗,先复制别人的代码跑通再说。

可以跑通的源代码:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt

batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='../dataset/mnist',
                               train=True,
                               download=True,
                               transform=transform)
test_dataset = datasets.MNIST(root='../dataset/mnist',
                              train=False,
                              download=True,
                              transform=transform)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)

'''
#查看数据,example_data为图片数据,example_targets为图片标签,图片的shape为32, 1, 28, 28,单通道,28*28的图片
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets)
print(example_data.shape)

#用matplotlib将部分图片显示出来看看
fig = plt.figure()
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.tight_layout()
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
  plt.title("Ground Truth: {}".format(example_targets[i]))
  plt.xticks([])
  plt.yticks([])
plt.show()
'''


# 定义残差块
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.channels = channels
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        return F.relu(x + y)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)  # 88 = 24x3 + 16

        self.rblock1 = ResidualBlock(16)
        self.rblock2 = ResidualBlock(32)

        self.mp = nn.MaxPool2d(2)
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        in_size = x.size(0)

        x = self.mp(F.relu(self.conv1(x)))
        x = self.rblock1(x)
        x = self.mp(F.relu(self.conv2(x)))
        x = self.rblock2(x)

        x = x.view(in_size, -1)
        x = self.fc(x)
        return x


model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 2000))
            running_loss = 0.0


def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, target = data
            inputs, target = inputs.to(device), target.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
        print('Accuracy on test set: %d' % (100 * correct / total))


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
        # 保存模型
        torch.save(model.state_dict(), 'myfirstmodel.pt')

'''模型的加载
model = torch.load(PATH)
model.eval()
'''

最终,模型的在测试集上的准确率在98-99%左右。

以上是关于pytorch:实践MNIST手写数字识别的主要内容,如果未能解决你的问题,请参考以下文章

【Pytorch+torchvision】MNIST手写数字识别(代码附最详细注释)

PyTorch基于CNN的手写数字识别(在MNIST数据集上训练)

图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)

pyTorch使用mnist数据集实现手写数字识别

手写数字识别——基于全连接层和MNIST数据集

pytorch学习实战第五篇:卷积神经网络实现MNIST手写数字识别