用pytorch做手写数字识别,识别l率达97.8%

Posted liuxinyu12378

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了用pytorch做手写数字识别,识别l率达97.8%相关的知识,希望对你有一定的参考价值。

pytorch做手写数字识别

效果如下:

技术图片

 

工程目录如下

技术图片

 

第一步  数据获取

下载MNIST库,这个库在网上,执行下面代码自动下载到当前data文件夹下

from torchvision.datasets import MNIST
import torchvision

mnist = MNIST(root=‘./data‘,train=True,download=True)

print(mnist)
print(mnist[0])
print(len(mnist))
img = mnist[0][0]
img.show()

  

dataset.py文件,读取数据并做预处理

 

‘‘‘
准备数据集
‘‘‘

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision


def mnist_dataset(train):

    func = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,),std=(0.3081,))
    ])

    #1.准备Mnist数据集
    return MNIST(root=‘./data‘,train=train,download=False,transform=func)

def get_dataloader(train = True):
    mnist = mnist_dataset(train)
    return DataLoader(mnist,batch_size=128,shuffle=True)

if __name__ == ‘__main__‘:
    for (images,labels) in get_dataloader():
        print(images.size())
        print(labels.size())
        break

 

  

 

models.py文件,定义训练的模型类

‘‘‘
定义模型
‘‘‘

import torch.nn as  nn
import torch.nn.functional as F

class MnistModel(nn.Module):

    def __init__(self):
        super(MnistModel,self).__init__()
        self.fc1 = nn.Linear(1*28*28,100)
        self.fc2 = nn.Linear(100,10)

    def forward(self,image):
        image_viewd = image.view(-1,1*28*28) #[batch_size,1*28*28]
        fc1_out = self.fc1(image_viewd) #[batch_size,100]
        fc1_out_relu = F.relu(fc1_out) #[batch_size,100]
        out = self.fc2(fc1_out_relu) #[batch_size,10]

        return F.log_softmax(out,dim=-1)  #带权损失计算交叉熵

 

cong.py文件,定义一些常亮,设置使用cpu还是GPU  

‘‘‘
项目配置
‘‘‘

import torch

train_batch_size = 128
test_batch_size = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  

train.py文件,模型训练文件,保存模型

"""
进行模型的训练
"""
from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import conf
from tqdm import tqdm
import numpy as np
import torch
import os
from test import eval

#1. 实例化模型,优化器,损失函数
model = MnistModel().to(conf.device)
optimizer = optim.Adam(model.parameters(),lr=1e-3)

#2. 进行循环,进行训练
def train(epoch):
    train_dataloader = get_dataloader(train=True)
    bar = tqdm(enumerate(train_dataloader),total=len(train_dataloader))
    total_loss = []
    for idx,(input,target) in bar:
        input = input.to(conf.device)
        target = target.to(conf.device)
        #梯度置为0
        optimizer.zero_grad()
        #计算得到预测值
        output = model(input)
        #得到损失
        loss = F.nll_loss(output,target)
        #反向传播,计算损失
        loss.backward()
        total_loss.append(loss.item())
        #参数的更新
        optimizer.step()
        #打印数据
        if idx%10 ==0 :
            bar.set_description_str("epcoh: idx:,loss::.6f".format(epoch,idx,np.mean(total_loss)))
            torch.save(model.state_dict(),"./models/model.pkl")
            torch.save(optimizer.state_dict(),"./models/optimizer.pkl")

if __name__ == ‘__main__‘:
    for i in range(10):
        train(i)
        eval()

 

test.py文件,模型测试文件,测试模型准确率  

‘‘‘
进行模型评估
‘‘‘

from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import conf
from tqdm import tqdm
import numpy as np
import torch
import os

def eval():
    #实例化模型,优化器,损失函数
    model = MnistModel().to(conf.device)

    if os.path.exists("./models/model.pkl"):
        model.load_state_dict(torch.load("./models/model.pkl"))

    test_dataloader = get_dataloader(train=False)
    total_loss = []
    total_acc = []
    with torch.no_grad():
        for input, target in test_dataloader:  # 2. 进行循环,进行训练
            input = input.to(conf.device)
            target = target.to(conf.device)
            # 计算得到预测值
            output = model(input)
            # 得到损失
            loss = F.nll_loss(output, target)
            # 反向传播,计算损失
            total_loss.append(loss.item())

            # 计算准确率
            ###计算预测值
            pred = output.max(dim=-1)[-1]
            total_acc.append(pred.eq(target).float().mean().item())
    print("test loss:,test acc:".format(np.mean(total_loss), np.mean(total_acc)))

# if __name__ == ‘__main__‘:
#     # for i in range(10):
#     #     train(i)
#     eval()

  

 

以上是关于用pytorch做手写数字识别,识别l率达97.8%的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch实现用CNN识别手写数字

用PyTorch构建基于卷积神经网络的手写数字识别模型

使用卷积网络做手写数字识别

PyTorch实现手写数字识别

PyTorch基于CNN的手写数字识别(在MNIST数据集上训练)

使用循环神经网络做手写数字识别