Tensorflow.keras:RNN 对 Mnist 进行分类

Posted

技术标签:

【中文标题】Tensorflow.keras:RNN 对 Mnist 进行分类【英文标题】:Tensorflow.keras: RNN to classify Mnist 【发布时间】:2020-11-13 06:04:08 【问题描述】:

我试图通过构建一个简单的数字分类器来理解 tensorflow.keras.layers.SimpleRNN。 Mnist 数据集的数字大小为 28X28。所以主要思想是在时间 t 内呈现图像的每一行。我在一些博客中出现过这个想法,例如,this one,它展示了这张图片:

所以我的RNN是这样的:

units=128
self.model = Sequential()        
self.model.add(layers.SimpleRNN(128, input_shape=(28,28)))
self.model.add(Dense(self.output_size, activation='softmax'))

我知道 RNN 是使用以下等式定义的:

参数:

W=w_hh,w_xh 和 V=v。

输入向量:x_t。

更新方程式:

h_t=f(w_hh h_t-1+w_xh x_t)。

y = v h_t.

问题:

    “units=128”的确切定义是什么?是w_hh,w_xh的神经元个数吗?有什么地方可以找到这些信息吗?

    如果我运行self.model.summary()

我明白了

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn (SimpleRNN)       (None, 128)               20096     
_________________________________________________________________
dense_35 (Dense)             (None, 10)                1290      
=================================================================
Total params: 21,386
Trainable params: 21,386
Non-trainable params: 0
_________________________

如何从单元数到这些参数数“20096”和“1290”?

    在本例中,序列的大小始终相同。但是,我正在处理文本,序列的大小可变。那么,究竟 input_shape=(28,28) 是什么意思呢?我在任何地方都找不到此信息。

【问题讨论】:

【参考方案1】:

    Units 是神经元的数量,它是该层输出的维度。可以在documentation 找到此信息。

    参数数量取决于层输入和单元数量。对于 SimpleRNN 层,这是 128 * 128 + 128 * 28 + 128 = 20096(参见 this answer)。对于密集层,这是 128 * 10 + 10 = 1290。添加这 10 和 128 是因为层中的偏置权重,默认情况下是打开的。

    input_shape = (28, 28) 表示您的网络将处理大小为 28x28 数据点的输入。由于第一个维度是批量维度,它将处理 28 个长度为 28 的向量(如图所示)。可变长度的输入通常被拆分以适应给定的 input_shape。如果输入小于 input_shape,则可以应用填充以使其适合。

【讨论】:

构建 RNN 的动机之一是可变输入大小的想法。所以,最后,输入(理论上)根本没有变化。我知道还有其他优点(和扩展),但在传统的神经网络中也可以使用填充。对吗? 没错。实例化输入层后,输入大小是固定的,但由于 RNN 具有内部状态,因此它可以很好地处理任意大小的顺序数据。是的,填充可用于传统网络。这是一种增加输入数据大小以使其适合网络操作的方法,无需外推。

以上是关于Tensorflow.keras:RNN 对 Mnist 进行分类的主要内容,如果未能解决你的问题,请参考以下文章

使用 Tensorflow.keras 组织项目。应该是 tf.keras.Model 的一个子类吗?

开源!基于TensorFlow/Keras/PyTorch实现对自然场景的文字检测及端到端的OCR中文文字识别

使用 optuna 进行优化时的 TensorFlow / keras 问题

自定义 TensorFlow Keras 优化器

TensorFlow Keras 层中的重新排序轴

Tensorflow+Keras用Tensorflow.keras的方法替代keras.layers.merge