Implementing a multilayer perceptron (MLP), i.e. a fully connected (FC) neural network, in PyTorch to classify MNIST handwritten digits
Posted by yaowuyangwei521
1. Import the required packages
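import torch
import numpy as np
from torchvision.datasets import mnist
from torch import nn
from torch.autograd import Variable  # deprecated since PyTorch 0.4: plain tensors support autograd directly
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import DataLoader
# Jupyter/IPython magic that draws plots inline; delete this line when running as a plain .py script
%matplotlib inline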
2. Define the transform applied to each MNIST image
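def data_transform(x):
    # PIL image -> float32 array scaled to [0, 1]
    x = np.array(x, dtype='float32') / 255
    # shift and scale to [-1, 1]
    x = (x - 0.5) / 0.5
    # flatten the 28x28 image into a 784-element vector for the fully connected input layer
    x = x.reshape((-1,))
    x = torch.from_numpy(x)
    return x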
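To see what data_transform does to a single image, the sketch below applies the same function to a synthetic 28x28 uint8 array, an assumption standing in for the PIL image that torchvision passes in (np.array accepts both). A pixel value of 0 maps to -1, 255 maps to 1, and the image is flattened into a 784-element vector, which is why the first layer of the network takes 28*28 inputs.

```python
import numpy as np
import torch

def data_transform(x):
    # Same transform as above: scale to [0, 1], shift to [-1, 1], flatten
    x = np.array(x, dtype='float32') / 255
    x = (x - 0.5) / 0.5
    x = x.reshape((-1,))
    return torch.from_numpy(x)

# Hypothetical stand-in for an MNIST image (synthetic data, not from the dataset)
img = np.zeros((28, 28), dtype=np.uint8)
img[0, 0] = 255  # one white pixel in the top-left corner

v = data_transform(img)
print(v.shape)                    # torch.Size([784])
print(v[0].item(), v[1].item())   # 1.0 -1.0
```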
3. Download the dataset and define the data iterators
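# Download MNIST into ./dataset/mnist (if not already present) and apply data_transform to every image
trainset = mnist.MNIST('./dataset/mnist', train=True, transform=data_transform, download=True)
testset = mnist.MNIST('./dataset/mnist', train=False, transform=data_transform, download=True)
# Reshuffle the training set every epoch; keep the test set in a fixed order
train_data = DataLoader(trainset, batch_size=64, shuffle=True)
test_data = DataLoader(testset, batch_size=128, shuffle=False)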
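Each iteration over train_data yields a batch of flattened images together with their labels, and len(train_data) counts batches, not samples; the training loop in step 6 divides by it. The sketch uses a small synthetic TensorDataset (hypothetical random data, chosen so it runs without downloading MNIST) to show that 1000 samples with batch_size=64 give 16 batches, the last holding only 40 samples. By the same arithmetic, the 60,000-image MNIST training set gives ceil(60000/64) = 938 batches and the 10,000-image test set gives ceil(10000/128) = 79.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for trainset: 1000 flattened 28x28 "images" with labels 0-9
images = torch.randn(1000, 28 * 28)
labels = torch.randint(0, 10, (1000,))
loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)

# len(loader) is the number of batches, i.e. ceil(samples / batch_size)
print(len(loader), math.ceil(1000 / 64))  # 16 16

# Every batch is full except the last one (drop_last defaults to False)
batch_sizes = [im.shape[0] for im, label in loader]
print(batch_sizes[0], batch_sizes[-1])  # 64 40
```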
4. Define the fully connected neural network (multilayer perceptron)
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        # four fully connected layers: 784 -> 500 -> 250 -> 125 -> 10
        self.fc1 = nn.Linear(28*28, 500)
        self.fc2 = nn.Linear(500, 250)
        self.fc3 = nn.Linear(250, 125)
        self.fc4 = nn.Linear(125, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        # no activation on the output layer: CrossEntropyLoss expects raw logits
        x = self.fc4(x)
        return x

mlp = MLP()
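As a quick sanity check of the architecture, the sketch below rebuilds the same four-layer network, passes a dummy batch of two flattened images through it, and counts the trainable parameters. Each nn.Linear(in, out) holds in*out weights plus out biases, so the total is (784·500+500) + (500·250+250) + (250·125+125) + (125·10+10) = 550,385. The all-zeros input is an assumption used only for the shape check.

```python
import torch
from torch import nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 500)
        self.fc2 = nn.Linear(500, 250)
        self.fc3 = nn.Linear(250, 125)
        self.fc4 = nn.Linear(125, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)  # raw logits, one per digit class

mlp = MLP()
dummy = torch.zeros(2, 28 * 28)  # hypothetical batch of two flattened images
print(mlp(dummy).shape)  # torch.Size([2, 10])

# weights + biases of all four layers
n_params = sum(p.numel() for p in mlp.parameters())
print(n_params)  # 550385
```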
5. Define the loss function and the optimizer
criterion = nn.CrossEntropyLoss()
# plain stochastic gradient descent with learning rate 1e-3 (no momentum)
optimizer = torch.optim.SGD(mlp.parameters(), lr=1e-3)
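The last layer fc4 has no softmax because nn.CrossEntropyLoss expects raw logits and applies log-softmax internally, followed by the negative log-likelihood loss. The sketch checks that equivalence numerically on a random batch; the logits and labels are hypothetical values, not MNIST outputs.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)          # hypothetical outputs for 4 samples, 10 classes
labels = torch.tensor([3, 0, 9, 1])  # hypothetical target digits

ce = nn.CrossEntropyLoss()(logits, labels)
# same loss built from its two parts: log-softmax, then negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
print(torch.allclose(ce, manual))  # True
```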
6. Train and evaluate
losses = []
acces = []
eval_losses = []
eval_acces = []

for e in range(20):
    train_loss = 0
    train_acc = 0
    mlp.train()
    for im, label in train_data:
        # forward pass
        out = mlp(im)
        loss = criterion(out, label)
        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # accumulate the loss
        train_loss += loss.item()
        # classification accuracy on this batch
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct / im.shape[0]
        train_acc += acc

    losses.append(train_loss / len(train_data))
    acces.append(train_acc / len(train_data))

    # evaluate on the test set
    eval_loss = 0
    eval_acc = 0
    mlp.eval()  # switch the model to evaluation mode
    with torch.no_grad():  # gradients are not needed for evaluation
        for im, label in test_data:
            out = mlp(im)
            loss = criterion(out, label)
            # accumulate the loss
            eval_loss += loss.item()
            # accumulate the accuracy
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            acc = num_correct / im.shape[0]
            eval_acc += acc

    eval_losses.append(eval_loss / len(test_data))
    eval_acces.append(eval_acc / len(test_data))
    print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
          .format(e, train_loss / len(train_data), train_acc / len(train_data),
                  eval_loss / len(test_data), eval_acc / len(test_data)))
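The loop above averages per-batch accuracy over len(test_data), so the smaller final test batch (10,000 = 78 × 128 + 16) counts as much as a full one. On MNIST the effect is small, but a sample-weighted alternative sums correct predictions and sample counts instead. The evaluate function below is a hypothetical helper, not part of the original code, shown on a tiny hand-made example where the two averages differ.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, criterion):
    """Return (per-sample mean loss, accuracy) over every sample in loader."""
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for im, label in loader:
            out = model(im)
            n = im.shape[0]
            # criterion returns the batch mean, so scale it back by batch size
            total_loss += criterion(out, label).item() * n
            total_correct += (out.argmax(1) == label).sum().item()
            total_seen += n
    return total_loss / total_seen, total_correct / total_seen

# Tiny hand-made example: nn.Identity passes the "logits" straight through.
x = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # predicts 0, 0, 1
y = torch.tensor([1, 1, 1])                              # only the last prediction is right
loader = DataLoader(TensorDataset(x, y), batch_size=2, shuffle=False)

loss, acc = evaluate(nn.Identity(), loader, nn.CrossEntropyLoss())
# per-batch averaging would report (0/2 + 1/1) / 2 = 0.5 here
print(acc)  # 0.3333333333333333
```

With the tutorial's objects, the call would be eval_loss, eval_acc = evaluate(mlp, test_data, criterion).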
7. Results
epoch: 0, Train Loss: 2.287240, Train Acc: 0.124150, Eval Loss: 2.265074, Eval Acc: 0.237540
epoch: 1, Train Loss: 2.237043, Train Acc: 0.385861, Eval Loss: 2.197773, Eval Acc: 0.524921
epoch: 2, Train Loss: 2.138911, Train Acc: 0.555487, Eval Loss: 2.050214, Eval Acc: 0.554292
epoch: 3, Train Loss: 1.901877, Train Acc: 0.563833, Eval Loss: 1.688784, Eval Acc: 0.592662
epoch: 4, Train Loss: 1.439467, Train Acc: 0.625483, Eval Loss: 1.178063, Eval Acc: 0.704905
epoch: 5, Train Loss: 1.022494, Train Acc: 0.745586, Eval Loss: 0.869467, Eval Acc: 0.778184
epoch: 6, Train Loss: 0.795575, Train Acc: 0.790528, Eval Loss: 0.702586, Eval Acc: 0.808347
epoch: 7, Train Loss: 0.665018, Train Acc: 0.816031, Eval Loss: 0.601074, Eval Acc: 0.831586
epoch: 8, Train Loss: 0.583082, Train Acc: 0.834588, Eval Loss: 0.535897, Eval Acc: 0.843750
epoch: 9, Train Loss: 0.527930, Train Acc: 0.848231, Eval Loss: 0.490443, Eval Acc: 0.857694
epoch: 10, Train Loss: 0.488764, Train Acc: 0.858925, Eval Loss: 0.456138, Eval Acc: 0.866396
epoch: 11, Train Loss: 0.459293, Train Acc: 0.868220, Eval Loss: 0.430784, Eval Acc: 0.873220
epoch: 12, Train Loss: 0.436398, Train Acc: 0.874117, Eval Loss: 0.413343, Eval Acc: 0.875890
epoch: 13, Train Loss: 0.418043, Train Acc: 0.880031, Eval Loss: 0.396967, Eval Acc: 0.880340
epoch: 14, Train Loss: 0.403195, Train Acc: 0.884029, Eval Loss: 0.385431, Eval Acc: 0.885483
epoch: 15, Train Loss: 0.390613, Train Acc: 0.887327, Eval Loss: 0.372552, Eval Acc: 0.889537
epoch: 16, Train Loss: 0.379947, Train Acc: 0.890275, Eval Loss: 0.363168, Eval Acc: 0.891812
epoch: 17, Train Loss: 0.370701, Train Acc: 0.893557, Eval Loss: 0.355597, Eval Acc: 0.894482
epoch: 18, Train Loss: 0.362498, Train Acc: 0.896572, Eval Loss: 0.348329, Eval Acc: 0.897844
epoch: 19, Train Loss: 0.354748, Train Acc: 0.898121, Eval Loss: 0.340272, Eval Acc: 0.899921
8. Training loss and training accuracy curves
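# training loss per epoch
plt.figure()
plt.title('train loss')
plt.xlabel('epoch')
plt.plot(np.arange(len(losses)), losses)

# training accuracy per epoch, in a separate figure
plt.figure()
plt.title('train acc')
plt.xlabel('epoch')
plt.plot(np.arange(len(acces)), acces)
plt.show()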
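The training loop also records eval_losses and eval_acces, but the figures in step 8 plot only the training curves. A small helper such as the hypothetical plot_curves below draws a training curve and its test-set counterpart on the same axes, which makes any gap between them (a sign of overfitting) easy to spot.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_curves(train_values, eval_values, title):
    """Plot one per-epoch training curve and its test-set counterpart."""
    fig, ax = plt.subplots()
    epochs = np.arange(len(train_values))
    ax.plot(epochs, train_values, label='train')
    ax.plot(epochs, eval_values, label='test')
    ax.set_xlabel('epoch')
    ax.set_title(title)
    ax.legend()
    return fig, ax
```

With the lists from step 6, the calls would be plot_curves(losses, eval_losses, 'loss') and plot_curves(acces, eval_acces, 'accuracy').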