基于循环神经网络的手写数字图像识别方法

Posted 统计家园

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了基于循环神经网络的手写数字图像识别方法相关的知识,希望对你有一定的参考价值。

在普通的全连接网络或卷积神经网络(CNN)中,每层神经元的信号只能向上一层传播,样本的处理在各个时刻相互独立,即没有考虑到人类视觉神经或听觉神经所接收序列的连续性。因此,专门用于处理序列的循环神经网络RNN(Recurrent Neural Network)诞生了。RNN模型的每一个神经元除了当前信息的输入外,还会保留之前产生的记忆信息。RNN可以用来处理连续的语音、连续的手写字等。

一、循环神经网络基本理论
对于常见的序列(语音或手写字等)长短不一,难以拆分成一个个独立的样本来训练。RNN模型一次只传输固定数量的输入样本给模型,可以分多次传递,传递次数根据数据而定。因此,RNN模型善于处理不定长度的输入。RNN模型假设样本是基于序列产生的。 假设一个序列其索引是从0到T,对于该序列中任意的索引号t对应的输入是样本x中的第t个元素x (t) ,模型在序列索引号t位置的隐藏状态h (t) 是由x (t) 和在t-1位置的隐藏状态h (t-1) 共同决定的,模型在t时刻的输出o (t) 是由h (t) 通过非线性变换得到的。
二、循环神经网络模型
RNN的基本结构是由输入层、隐藏层和输出层构成,与CNN和DNN不同的是RNN在隐藏层中将上次输出结果也作为输入,用数学形式来表示为:h t =f w (h t-1 x t ),RNN的结构如图1所示。

图1 RNN模型结构

如图1所示,在隐藏层与隐藏层之间的循环连接由权重矩阵W参数化,表示当前时间与前一期之间的信息传递;输入层与隐藏层的循环连接由权重U参数化,表示当前时间的外部输入信号到隐藏单元的转换;隐藏层与输出层之间的循环连接由权重矩阵V参数化,表示当前时间的隐藏单元到输出层之间的转换;损失L衡量输出O和训练目标Y的距离。

其中U、V、W三个矩阵在整个RNN网络中是共享的,其原因有三,一是RNN对输入序列长度不可预知,为了实现可变长度需要权重共享。如果每一步都有不同的权重,那么不同长度的输入权重的个数也会不一样,不但不能泛化到训练时没有见过的序列长度,也不能在时间上共享不同序列长度和不同位置的统计强度。二是权重共享可以减少参数数量,节省内存和运算时间。三是权重共享可以保证每个时间下的输入和隐藏层都来自同一种变换方式,即假设每一步的信息都是平等的。

(一)前向传播算法

对于t时刻有h(t)=fw(W*h(t-1)+U*x(t)+b),其中,fw为激活函数,一般会选用tanh函数,b为偏置。t时刻的输出为o(t)=V*h(t)+c,最终模型输出为y(t)=σ(o(t)),对于分类任务一般最后的激活函数会选用softmax函数。将RNN的结构按时间序列展开后如图2所示。

基于循环神经网络的手写数字图像识别方法

图2 按时间序列展开的RNN模型结构

在前向传播过程中用tanh作为激活函数而不是sigmoid函数,因为sigmoid函数的导数取值范围在(0,0.25],tanh函数的导数的取值范围是(0,1],二者的导数都不大于1,这会导致在接下来的反向传播算法的累乘过程中随时间序列的不断深入,累成结果不断减小,导致梯度越来越接近0,即出现“梯度消失”现象。但相比较而言,tanh函数比sigmoid函数梯度消失过程慢,故通常选用tanh函数作为激活函数。此外,sigmoid函数输出不是零中心对称,其输出均大于0,这就使得输出不是0均值,称为偏移现象。偏移现象将导致后一层的神经元将上一层输出的非0均值的信号作为输入。tanh函数具有关于原点对称输入和中心对称输出的特点,网络会收敛地更好。

(二)反向传播算法

RNN的反向传播与普通神经网络区别不大。BPTT(back-propagation through time)算法是常用的训练RNN的方法。BPTT的中心思想是沿着需要优化的参数的负梯度方向不断寻找更优的点直至收敛。BPTT的本质还是BP算法,只不过RNN处理的是时间序列数据,所以要基于时间反向传播,故叫随时间反向传播。

对于RNN,由于在序列的每个位置都有损失函数,因此最终的损失为: 

基于循环神经网络的手写数字图像识别方法

可以得到V、W、U的偏导:

基于循环神经网络的手写数字图像识别方法

对于连乘部分引入tanh激活函数后可以表示为:

基于循环神经网络的手写数字图像识别方法

三、RNN模型的演变

长短时记忆神经网络LSTM( Long ShortMemory Network)是一种变种的RNN,它是一种将以往学习的结果应用到当前学习的模型。LSTM模型在RNN的基础上引入细胞状态,根据细胞状态可以决定哪些状态应该保留下来,哪些状态应该被遗忘。LSTM可以解决RNN无法适当处理的远距离依赖问题,即在一定程度上可以解决梯度消失的问题。
LSTM 中最主要的是三个门控单元:输入门、遗忘门、输出门。输入门控制网络的输入;遗忘门的作用是决定哪些信息需要被记住,哪些需要被遗忘。遗忘门是LSTM的核心,输出门控制网络的输出。LSTM的结构如图3所示。

基于循环神经网络的手写数字图像识别方法

图3 LSTM模型基本结构

其中,三个门控单元结构细节如图4所示,LSTM模型与两个隐藏状态ht和Ct。LSTM模型参数量远远多于简单RNN模型,根据图示可以看出LSTM包含三个输入:上时刻的单元状态、上时刻LSTM的输出以及当前时刻的输入。

基于循环神经网络的手写数字图像识别方法

图4 LSTM模型门控单元结构

遗忘门输出为ft=σ(Wfh(t-1)+Ufx(t)+bf),σ为sigmiod激活函数。输入门输出为i(t)=σ(Wih(t-1)+Uix(t)+bi),a(t)=tanh(Wah(t-1)+Uax(t)+ba)。输出门之前首先要看一看LSTM的细胞状态,经过遗忘门和输入门的结果都会作用在细胞状态C(t),C(t)= C(t-1)⊙f(t)+i(t)⊙a(t)。输出门的数学表达式为o(t)=σ(Woh(t-1)+Uox(t)+bo),h(t)= o(t)⊙tanh(C(t))。最后更新当前序列索引预测输出y(t)=σ(Vh(t)+c)。LSTM模型的反向传播与RNN相同,都是借助梯度下降来训练模型。

四、基于LSTM模型识别手写数字图像

(一)数据来源

采用公开数据集Mnist数据集,在Pytorch环境中完成LSTM模型的训练与测试。数据集包含训练集60000张手写数字图像以及相应标签数据,测试集10000张图像及标签。如图5所示为训练集中随机抽取的两张手写数字图像与标签。

基于循环神经网络的手写数字图像识别方法

图5 数据集中手写数字“2”和“5”

(二)模型搭建及结果分许

选择输入图像尺寸为28×28,分批训练,batc_size设为64,隐藏层神经元个数设置为64,隐藏层数默认为1,初始学习率设为0.01,经尝试后epoch设为1,每训练50次保留一次Loss。优化器选用Adam,损失函数选择交叉熵函数。随着优化过程,训练损失函数逐渐减小(如图6所示),测试准确度不断上升(如图7所示),最终趋于平稳,可以对应于图8所示的每经过50次训练所得的损失函数值及正确率,可见最终正确率稳定于95%。

基于循环神经网络的手写数字图像识别方法  

图6 LSTM模型训练的损失函数曲线

基于循环神经网络的手写数字图像识别方法

图7 LSTM模型训练的正确率变化曲线

图8 LSTM模型训练过程

import torchimport torch.nn as nnimport torch.utils.data as Datafrom torch.autograd import Variableimport torchvision.datasets as dsetsimport matplotlib.pyplot as pltimport torchvision.transforms as transformsEPOCH = 1BATCH_SIZE = 64LR = 0.01 # 学习率DOWNLOAD_MNIST = False #已下载好数据集,就设置为False,否则为TRUETIME_STEP=28 INPUT_SIZE=28    #输入图像尺寸loss_list = []iteration_list = []accuracy_list = []## Mnist digits datasetif not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'):    #若没有 mnist 路径 or mnist 是空的 DOWNLOAD_MNIST = Truetrain_data = dsets.MNIST( root='./mnist/',    train=True,                           # training data    transform=transforms.ToTensor(),  download=DOWNLOAD_MNIST,)## plot one exampleprint(train_data.train_data.size()) # (60000, 28, 28)print(train_data.train_labels.size()) # (60000)plt.imshow(train_data.train_data[0].numpy(), cmap='gray')plt.title('%i' % train_data.train_labels[0])plt.show()##training loadertrain_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)## pick 2000 samples to speed up testingtest_data=dsets.MNIST(root='./mnist/',train=False,transform=transforms.ToTensor())test_x = test_data.test_data.type(torch.FloatTensor)[:2000]/255. # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)test_y = test_data.test_labels.numpy()[:2000]class RNN(nn.Module): def __init__(self): super(RNN, self).__init__() self.rnn = nn.LSTM( input_size=INPUT_SIZE, hidden_size=64, num_layers=1, batch_first=True ) self.out=nn.Linear(64,10) def forward(self,x): r_out,(h_n,h_c)=self.rnn(x,None) out=self.out(r_out[:,-1,:]) #数据格式为[batch,time_step,input],因此输出参考的是最后时刻的数据 return outrnn=RNN()print(rnn) # net architectureoptimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all cnn parametersloss_func = nn.CrossEntropyLoss() # the target label is not one-hottedfor epoch in range(EPOCH): for step, (x, y) in enumerate(train_loader): # gives batch data, normalize x when iterate train_loader b_x=Variable(x.view(-1,28,28)) b_y=Variable(y) output = rnn(b_x) # cnn output loss = loss_func(output, b_y) # cross entropy loss optimizer.zero_grad() # clear gradients for this training step loss.backward() # backpropagation, compute gradients optimizer.step() # apply gradients if step % 50 == 0: test_output = rnn(test_x) pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze() accuracy =float((pred_y==test_y).astype(int).sum())/float(test_y.size) print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy) iteration_list.append(step+50) loss_list.append(loss.data.numpy()) accuracy_list.append(accuracy)##visualization lossplt.plot(iteration_list,loss_list)plt.xlabel("Number of iteration")plt.ylabel("Loss")plt.title("LSTM:Loss vs Number of iteration")plt.show()##visualization accuracyplt.plot(iteration_list,accuracy_list,color="red")plt.xlabel("Number of iteration")plt.ylabel("Accuracy")plt.title("LSTM:Accuracy vs Number of iteration")plt.savefig('graph.png')plt.show()## print 10 predictions from test datatest_output = rnn(test_x[:10].view(-1,28,28))pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()print(pred_y, 'prediction number')print(test_y[:10], 'real number')



statgarden稿jinjian626@163.com

以上是关于基于循环神经网络的手写数字图像识别方法的主要内容,如果未能解决你的问题,请参考以下文章

手写数字识别——基于全连接层和MNIST数据集

手写数字识别基于matlab GUI欧拉数和二维矩阵相关系数手写数字识别含Matlab源码 1896期

手写数字识别基于matlab GUI BP神经网络单个或连续手写数字识别系统含Matlab源码 2296期

手写数字识别基于matlab GUI BP神经网络单个或连续手写数字识别系统含Matlab源码 2296期

使用循环神经网络做手写数字识别

手写数字识别基于matlab GUI BP神经网络手写数字识别系统含Matlab源码 1639期