循环神经网络系列基于LSTM的MNIST手写体识别

Posted 月来客栈

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了循环神经网络系列基于LSTM的MNIST手写体识别相关的知识,希望对你有一定的参考价值。


我们知道循环神经网络是用来处理包含序列化数据的相关问题,所有若要利用循环神经网络来解决某类问题,那么首先要做的就是将原始数据集序列化,然后处理成某个深度学习框架所接受的数据输入格式(比如Tensorflow).


1.数据预处理

我们知道MNIST数据集中的每张图片形状都是​​[28*28]​​​的,那么如何将其序列化呢?想要知道怎么序列化,那还得从LSTM接受怎样的输入说起。由于​​前文​​我们说到,一个LSTM单元可以按时间维度进行展开处理,那么对于一个黑白图片来说该怎么展开呢?并且展开后必须得有前后的序列关系。 最直接的想法当然就是每个像素点当成一个部分,直接展开成784个LSTM单元。这种做法理论上当然没有什么问题,只是显得略微有点粗暴了;那我们就稍微缓和一点,按行或者按列来分割成28行(列),然后将这28个部分看成是序列。如下图所示:

循环神经网络系列(四)基于LSTM的MNIST手写体识别_序列化

所以,对于整个数据集来说:我们第一步要做的就是将其​​reshape​​​成​​[batchsize,high,width]​​​的这种形式;然后第二步就是将其按行分割成28个部分,变成​​[timestep,batchsize,dim]​​​。以​​batchsize = 4​​为例,可以画出如下示意图:

循环神经网络系列(四)基于LSTM的MNIST手写体识别_序列化_02

温馨提示:4种颜色分别表示4张图片

因此,这部分对应代码就是:

    x = tf.placeholder(dtype=tf.float32, shape=[None, 784], name=input-x)
y = tf.placeholder(dtype=tf.int32, shape=[None], name=input-y)
x_reshape = tf.reshape(x, shape=[-1, DIM, DIM], name=reshape-x)
x_tranpose = tf.transpose(x_reshape, perm=[1, 0, 2], name=transpose-x)

其中第4行就表示按行分割成28个部分,把所有的第i行都放在一起;

对于得到的这种形式的数据,我们在喂给展开后的LSTM时是长下面这个样的:

循环神经网络系列(四)基于LSTM的MNIST手写体识别_数据集_03

2.搭建网络

从​​上文​​的介绍可知,经过LSTM处理后,输出结果的格式是:​​[timesteps,batchsize,outputsize]​​。又由于我们做的仅仅是分类任务,所以我们接下来就取最后一个​​timesteps​​的输出作为整个LSTM的输出即可。同时,作为分类任务,我们需要得到每个类别的预测概率,因此还需要再LSTM的输出结果后加上一个​​softmax​​层,到此网络结构的搭建就完了。

循环神经网络系列(四)基于LSTM的MNIST手写体识别_数据集_04

这部分对应代码就是:

def lstm(inputs):
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=OUTPUT_SIZE)
h0 = cell.zero_state(batch_size=tf.shape(inputs)[1], dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs=inputs, initial_state=h0, time_major=True)
return outputs[-1]

y_ = lstm(x_tranpose)
with tf.name_scope(weighted-softmax):
weights = tf.Variable(tf.truncated_normal(shape=[OUTPUT_SIZE, OUTPUT_SIZE], stddev=0.1), dtype=tf.float32)
bias = tf.Variable(tf.constant(0, shape=[OUTPUT_SIZE], dtype=tf.float32))
logits = tf.nn.xw_plus_b(y_, weights, bias, name=softmax)

​源码戳此处​

更多内容欢迎扫码关注公众号月来客栈!

循环神经网络系列(四)基于LSTM的MNIST手写体识别_数据集_05



以上是关于循环神经网络系列基于LSTM的MNIST手写体识别的主要内容,如果未能解决你的问题,请参考以下文章

图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)

数据挖掘入门系列教程之使用神经网络(基于pybrain)识别数字手写集MNIST

图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(单向LSTM,附完整代码和数据集)

caffe lstm训练mnist手写数字

caffe lstm训练mnist手写数字

tensorflow笔记之MNIST手写识别系列一