pyTorch使用mnist数据集实现手写数字识别
Posted 1-0001
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pyTorch使用mnist数据集实现手写数字识别相关的知识,希望对你有一定的参考价值。
使用mnist数据集实现手写数字识别是入门必做吧。这里使用pyTorch框架进行简单神经网络的搭建。
首先导入需要的包。
1 import torch 2 import torch.nn as nn 3 import torch.utils.data as Data 4 import torchvision
接下来需要下载mnist数据集。我们创建train_data。使用torchvision.datasets.MNIST进行数据集的下载。
1 train_data = torchvision.datasets.MNIST( 2 root=‘./mnist/‘, #下载到该目录下 3 train=True, #为训练数据 4 transform=torchvision.transforms.ToTensor(), #将其装换为tensor的形式 5 download=True, #第一次设置为true表示下载,下载完成后,将其置成false 6 )
之后将其导入data_loader中,这个数据加载类会自动帮我们进行数据集的切片。
1 train_data = torchvision.datasets.MNIST( 2 root=‘./mnist‘, 3 train=True, 4 transform=torchvision.transforms.ToTensor(), 5 download=False 6 ) 7 train_loader = Data.DataLoader(dataset=train_data, batch_size=32, shuffle=True, num_workers=0) 8 test_data = torchvision.datasets.MNIST( 9 root=‘./mnist‘, 10 train=False, 11 transform=torchvision.transforms.ToTensor(), 12 ) 13 test_loader = Data.DataLoader(dataset=test_data, batch_size=32, shuffle=False, num_workers=0) 14 test_num = len(test_data)
之后开始定义我们的模型,由于minist数据集是灰度图像,并且图片的size都是(28, 28, 1),所以输入图片的时候不需要进行额外的修改。
1 class Net(nn.Module): 2 def __init__(self): 3 super(Net, self).__init__() 4 self.conv1 = nn.Sequential(#(1, 28, 28) 5 nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),#(16, 28, 28) 6 nn.ReLU(),#(16, 28, 28) 7 nn.MaxPool2d(kernel_size=2)#(16, 14, 14) 8 ) 9 self.conv2 = nn.Sequential( 10 nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),#(32, 14, 14) 11 nn.ReLU(),#(32, 14, 14) 12 nn.MaxPool2d(kernel_size=2)#(32, 7, 7) 13 ) 14 self.fc = nn.Linear(32 * 7 * 7, 10) 15 def forward(self, x): 16 x = self.conv1(x) 17 x = self.conv2(x) 18 x = x.view(x.size(0), -1) 19 x = self.fc(x) 20 return x
特别注意在最后传入全连接层时,最好自己将x的size改变以确保不会因为自适应而造成错误。因为在传入全连接层时会默认压缩成二维,例如[1, 2, 3, 4]会被压缩成[1*2, 3*4]。
之后开始训练。
1 net = Net() 2 loss_fn = nn.CrossEntropyLoss() 3 optim = torch.optim.Adam(net.parameters(), lr = 0.001) 4 5 save_path = ‘./mnist.pth‘ 6 best_acc = 0.0 7 for epoch in range(3): 8 9 net.train() 10 running_loss = 0.0 11 for step, data in enumerate(train_loader, start=0): 12 images, labels = data 13 optim.zero_grad() 14 logits = net(images) 15 loss = loss_fn(logits, labels) 16 loss.backward() 17 optim.step() 18 19 20 running_loss += loss.item() 21 rate = (step+1)/len(train_loader) 22 a = "*" * int(rate * 50) 23 b = "." * int((1 - rate) * 50) 24 print(" train loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="") 25 print() 26 27 net.eval() 28 acc = 0.0 29 with torch.no_grad(): 30 for data_test in test_loader: 31 test_images, test_labels = data_test 32 outputs = net(test_images) 33 predict_y = torch.max(outputs, dim=1)[1]#torch.max返回两个数值,一个是最大值,一个是最大值的下标 34 acc += (predict_y == test_labels).sum().item() 35 test_accurate = acc / test_num 36 if test_accurate > best_acc: 37 best_acc = test_accurate 38 torch.save(net.state_dict(), save_path) 39 print(‘[epoch %d] train_loss: %.3f test_accuracy: %.3f‘ % 40 (epoch + 1, running_loss / step, test_accurate)) 41 42 print(‘Finished Training‘)
在完成训练后,训练的权重会保存在所设置路径下的文件中,进行预测的时候,建立模型,载入权重,照一张数字的图片,对其进行裁剪,灰度等操作之后加载入模型进行预测。
1 from PIL import Image 2 import matplotlib.pyplot as plt 3 from torchvision import transforms 4 import torch 5 from model import Net 6 7 img = Image.open("./YLY2@}8UMGLW37S$)NCVZ23.png") 8 9 plt.imshow(img) 10 11 # [N, C, H, W] 12 13 train_transform = transforms.Compose([ 14 transforms.Grayscale(), 15 transforms.Resize((28, 28)), 16 transforms.ToTensor(), 17 ]) 18 19 img = train_transform(img) 20 # expand batch dimension 21 img = torch.unsqueeze(img, dim=0) 22 23 # create model 24 model = Net() 25 # load model weights 26 model_weight_path = "./mnist.pth" 27 model.load_state_dict(torch.load(model_weight_path)) 28 29 index_to_class = [‘0‘, ‘1‘, ‘2‘, ‘3‘, ‘4‘, ‘5‘, ‘6‘, ‘7‘, ‘8‘, ‘9‘] 30 31 32 model.eval() 33 with torch.no_grad(): 34 # predict class 35 y = model(img) 36 #print(y.size()) 37 output = torch.squeeze(y) 38 #print(output) 39 predict = torch.softmax(output, dim=0) 40 #print(predict) 41 predict_cla = torch.argmax(predict).numpy() 42 #print(predict_cla) 43 print(index_to_class[predict_cla], predict[predict_cla].numpy()) 44 plt.show()
需要注意的是,载入模型的图片必须多一个维度batch,所以我们用img = torch.unsqueeze(img, dim=0)在图片的开头增加一个batch维度。
之后载入图片,得到输出,将输出的batch维度压缩掉,使用softmax函数得到概率分布,再用argmax函数得到最大值的下标,打印最大值所对应的类别及其概率。
以上是关于pyTorch使用mnist数据集实现手写数字识别的主要内容,如果未能解决你的问题,请参考以下文章
图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(单向LSTM,附完整代码和数据集)
PyTorch基于CNN的手写数字识别(在MNIST数据集上训练)
pytorch学习实战第五篇:卷积神经网络实现MNIST手写数字识别