一种轻量化网络实现mnist数据集分类

Posted ZHW_AI课题组

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了一种轻量化网络实现mnist数据集分类相关的知识,希望对你有一定的参考价值。

作者介绍

刘舒婷,女,西安工程大学电子信息学院,2019级硕士研究生,张宏伟人工智能课题组。
研究方向:机器视觉与人工智能。
电子邮箱:913983238@qq.com

1.导入所需的包

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss,BCELoss
from torch.optim import Adam,lr_scheduler
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy

2.下载mnist数据集

网络不好的情况下,会出现报错,重新下载即可。

input_size = 28*28 # MNIST上的图像尺寸是 28x28
output_size = 10  # 类别为 09 的数字,因此为10类

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))])),
    batch_size=64, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))])),
    batch_size=1000, shuffle=True)

下载完成后,可以看到同目录下会出现一个data文件。也可提前将数据集下载完成后,与需要运行的程序放在同一文件夹下。
在这里插入图片描述

3.构造算法函数

class CNN(nn.Module):
  def __init__(self, input_size, n_feature, output_size):
    # 执行父类的构造函数,所有的网络都要这么写
    super(CNN, self).__init__()
    # 下面是网络里典型结构的一些定义,一般就是卷积和全连接
    # 池化、ReLU一类的不用在这里定义
    self.n_feature = n_feature
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=n_feature, kernel_size=5)
    self.conv2 = nn.Conv2d(n_feature, n_feature, kernel_size=5)
    self.fc1 = nn.Linear(n_feature*4*4, 50)
    self.fc2 = nn.Linear(50,10)

  # 下面的 forward 函数,定义了网络的结构,按照一定顺序,把上面构建的一些结构组织起来
  # 意思就是,conv1, conv2 等等的,可以多次重用
  def forward(self, x, verbose=False):
    x = self.conv1(x)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2)
    x = self.conv2(x)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2)
    x = x.view(-1, self.n_feature*4*4)
    x = self.fc1(x)
    x = F.relu(x)
    x = self.fc2(x)
    x = F.log_softmax(x, dim=1)
    return x

4.定义训练函数及测试函数

选用交叉熵损失函数,Adam优化器。

loss_f = nn.CrossEntropyLoss()
# 测试函数
def train(model):
  model.train()
  # 从train_loader里,64个样本一个batch为单位提取样本进行训练
  for batch_idx, (data, target) in enumerate(train_loader):
    #data = data.to(device)
    #target = target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = loss_f(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % 100 == 0:
      print('Train:[{}/{} ({:.0f}%)]\\tLoss:{:.6f}'.format(batch_idx * len(data),len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
  
def test(model):
  model.eval()
  test_loss = 0
  correct = 0
  for data, target in test_loader:
    # 把数据传入GPU中
    #data, target = data.to(device), target.to(device)
    # 把数据送入模型,得到预测结果
    output = model(data)
    # 计算本次batch的损失,并加入到test_loss中
    '''
    output.max(1, keepdim=True)--->返回每一行中最大的元素并返回索引,返回了两个数组
    output.max(1, keepdim=True)[1] 就是取第二个数组,取索引数组。
    '''
    test_loss += loss_f(output, target).item() #, reduction = 'sum').item()
    # get the index of the max log-probability, 最后一层输出10个数
    # 值最大的那个即对应着分类结果,然后把分类结果保存到pred里
    pred = output.data.max(1, keepdim=True)[1]
    # 将 pred 与 target 相比,得到正确预测结果的数量,并加到 correct 中
    # 这里需要注意一下 view_as ,意思是把 target 变成维度和 pred 一样的意思   
    correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()

  test_loss /= len(test_loader.dataset)
  accuracy = 100. * correct / len(test_loader.dataset)
  print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))

5.调用算法函数、训练函数、测试函数。开始运行

n_features = 6
##调用CNN模型
model_cnn = CNN(input_size, n_features, output_size) 
##model_cnn.to(device)
optimizer = optim.Adam(model_cnn.parameters(),lr=0.01)#, momentum=0.5)
print('Number of parameters: {}'.format((model_cnn)))

train(model_cnn) ##调用训练函数
test(model_cnn)  ##调用测试函数

运行结果如下:

在这里插入图片描述

以上是关于一种轻量化网络实现mnist数据集分类的主要内容,如果未能解决你的问题,请参考以下文章

神经网络的学习-搭建神经网络实现mnist数据集分类

基于pytorch平台实现对MNIST数据集的分类分析(前馈神经网络softmax)基础版

如何搭建VGG网络,实现Mnist数据集的图像分类

Pytorch Note25 深层神经网络实现 MNIST 手写数字分类

Pytorch实现RNN网络对MNIST字体分类

Lenet实现mnist数据集分类