pytorch:实践MNIST手写数字识别
Posted Z|Star
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch:实践MNIST手写数字识别相关的知识,希望对你有一定的参考价值。
在本专栏第十篇记录过CNN的理论,并大致了解使用CNN+残差网络训练MNIST的方式,由于课件中不包含完整代码,因此想要复现一遍,但遇到各种各样的坑,纸上得来,终觉浅~
第一个问题:MNIST数据集的获取
train_dataset = datasets.MNIST(root='../dataset/mnist',
train=True,
download=True,
transform=transform)
在datasets.MNIST的中可以设置download=True,这样设置,系统会自动在root里面检测MNIST数据文件,如果存在则不下载,如果不存在则自动联网下载。我尝试自动联网下载,结果十几分钟之后,下载一半之后报错,网络出现问题。于是翻阅其它资源,将其手动下载下来添加到minst文件夹中自动创建的raw文件夹中。
(如果你也需要这个数据集,可以在微信公众号“我有一计”内回复“数据集”,即可获取下载链接)
第二个问题:batch_size的大小的选取
回顾一下之前就记录过的三个概念:epoch、 iteration和batchsize
1)batchsize:批大小。在深度学习中,一般采用SGD训练,即每次训练在训练集中取batchsize个样本训练;
2)iteration:1个iteration等于使用batchsize个样本训练一次;
3)epoch:1个epoch等于使用训练集中的全部样本训练一次;
GPU对2的幂次的batch可以发挥更佳的性能,因此设置成16、32、64、128时往往要比设置为整10、整100的倍数时表现更优。
在现存允许的情况下batch_size可以取相对大一些
第三个问题:维度匹配
深度学习最麻烦的就是维度匹配,按照课件手打的代码出现维度不匹配的警告,具体原因尚不明朗,先复制别人的代码跑通再说。
可以跑通的源代码:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
batch_size = 64
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='../dataset/mnist',
train=True,
download=True,
transform=transform)
test_dataset = datasets.MNIST(root='../dataset/mnist',
train=False,
download=True,
transform=transform)
train_loader = DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
'''
#查看数据,example_data为图片数据,example_targets为图片标签,图片的shape为32, 1, 28, 28,单通道,28*28的图片
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets)
print(example_data.shape)
#用matplotlib将部分图片显示出来看看
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title("Ground Truth: {}".format(example_targets[i]))
plt.xticks([])
plt.yticks([])
plt.show()
'''
# 定义残差块
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.channels = channels
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
def forward(self, x):
y = F.relu(self.conv1(x))
y = self.conv2(y)
return F.relu(x + y)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5) # 88 = 24x3 + 16
self.rblock1 = ResidualBlock(16)
self.rblock2 = ResidualBlock(32)
self.mp = nn.MaxPool2d(2)
self.fc = nn.Linear(512, 10)
def forward(self, x):
in_size = x.size(0)
x = self.mp(F.relu(self.conv1(x)))
x = self.rblock1(x)
x = self.mp(F.relu(self.conv2(x)))
x = self.rblock2(x)
x = x.view(in_size, -1)
x = self.fc(x)
return x
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
def train(epoch):
running_loss = 0.0
for batch_idx, data in enumerate(train_loader, 0):
inputs, target = data
inputs, target = inputs.to(device), target.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
if batch_idx % 300 == 299:
print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 2000))
running_loss = 0.0
def test():
correct = 0
total = 0
with torch.no_grad():
for data in test_loader:
inputs, target = data
inputs, target = inputs.to(device), target.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, dim=1)
total += target.size(0)
correct += (predicted == target).sum().item()
print('Accuracy on test set: %d' % (100 * correct / total))
if __name__ == '__main__':
for epoch in range(10):
train(epoch)
test()
# 保存模型
torch.save(model.state_dict(), 'myfirstmodel.pt')
'''模型的加载
model = torch.load(PATH)
model.eval()
'''
最终,模型的在测试集上的准确率在98-99%左右。
以上是关于pytorch:实践MNIST手写数字识别的主要内容,如果未能解决你的问题,请参考以下文章
【Pytorch+torchvision】MNIST手写数字识别(代码附最详细注释)
PyTorch基于CNN的手写数字识别(在MNIST数据集上训练)