循环神经网络系列基于LSTM的MNIST手写体识别
Posted 月来客栈
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了循环神经网络系列基于LSTM的MNIST手写体识别相关的知识,希望对你有一定的参考价值。
我们知道循环神经网络是用来处理包含序列化数据的相关问题,所有若要利用循环神经网络来解决某类问题,那么首先要做的就是将原始数据集序列化,然后处理成某个深度学习框架所接受的数据输入格式(比如Tensorflow).
1.数据预处理
我们知道MNIST数据集中的每张图片形状都是[28*28]
的,那么如何将其序列化呢?想要知道怎么序列化,那还得从LSTM接受怎样的输入说起。由于前文我们说到,一个LSTM单元可以按时间维度进行展开处理,那么对于一个黑白图片来说该怎么展开呢?并且展开后必须得有前后的序列关系。 最直接的想法当然就是每个像素点当成一个部分,直接展开成784个LSTM单元。这种做法理论上当然没有什么问题,只是显得略微有点粗暴了;那我们就稍微缓和一点,按行或者按列来分割成28行(列),然后将这28个部分看成是序列。如下图所示:
所以,对于整个数据集来说:我们第一步要做的就是将其reshape
成[batchsize,high,width]
的这种形式;然后第二步就是将其按行分割成28个部分,变成[timestep,batchsize,dim]
。以batchsize = 4
为例,可以画出如下示意图:
温馨提示:4种颜色分别表示4张图片
因此,这部分对应代码就是:
x = tf.placeholder(dtype=tf.float32, shape=[None, 784], name=input-x)
y = tf.placeholder(dtype=tf.int32, shape=[None], name=input-y)
x_reshape = tf.reshape(x, shape=[-1, DIM, DIM], name=reshape-x)
x_tranpose = tf.transpose(x_reshape, perm=[1, 0, 2], name=transpose-x)
其中第4行就表示按行分割成28个部分,把所有的第i行都放在一起;
对于得到的这种形式的数据,我们在喂给展开后的LSTM时是长下面这个样的:
2.搭建网络
从上文的介绍可知,经过LSTM处理后,输出结果的格式是:[timesteps,batchsize,outputsize]
。又由于我们做的仅仅是分类任务,所以我们接下来就取最后一个timesteps
的输出作为整个LSTM的输出即可。同时,作为分类任务,我们需要得到每个类别的预测概率,因此还需要再LSTM的输出结果后加上一个softmax
层,到此网络结构的搭建就完了。
这部分对应代码就是:
def lstm(inputs):
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=OUTPUT_SIZE)
h0 = cell.zero_state(batch_size=tf.shape(inputs)[1], dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs=inputs, initial_state=h0, time_major=True)
return outputs[-1]
y_ = lstm(x_tranpose)
with tf.name_scope(weighted-softmax):
weights = tf.Variable(tf.truncated_normal(shape=[OUTPUT_SIZE, OUTPUT_SIZE], stddev=0.1), dtype=tf.float32)
bias = tf.Variable(tf.constant(0, shape=[OUTPUT_SIZE], dtype=tf.float32))
logits = tf.nn.xw_plus_b(y_, weights, bias, name=softmax)
源码戳此处
更多内容欢迎扫码关注公众号月来客栈!
以上是关于循环神经网络系列基于LSTM的MNIST手写体识别的主要内容,如果未能解决你的问题,请参考以下文章
图像分类基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)
数据挖掘入门系列教程之使用神经网络(基于pybrain)识别数字手写集MNIST