LSTMCell中num_units参数解释

Posted 琥珀彩

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了LSTMCell中num_units参数解释相关的知识,希望对你有一定的参考价值。

前言
关于LSTM原理: http://colah.github.io/posts/2015-08-Understanding-LSTMs/
关于LSTM原理(译文):https://blog.csdn.net/Jerr__y/article/details/58598296
关于Tensorflow+LSTM的使用:https://www.knowledgemapper.com/knowmap/knowbook/jasdeepchhabra94@gmail.comUnderstandingLSTMinTensorflow(MNISTdataset)
关于Tensorflow+LSTM的使用(译文):https://yq.aliyun.com/articles/202939
 

正文
本文只是介绍tensorflow中的BasicLSTMCell中num_units。
在使用Tensorflow跑LSTM的试验中, 有个num_units的参数,这个参数是什么意思呢?

先总结一下,num_units这个参数的大小就是LSTM输出结果的维度。例如num_units=128, 那么LSTM网络最后输出就是一个128维的向量。

我们先换个角度举个例子,最后再用公式来说明。

假设在我们的训练数据中,每一个样本 x 是 28*28 维的一个矩阵,那么将这个样本的每一行当成一个输入,通过28个时间步骤展开LSTM,在每一个LSTM单元,我们输入一行维度为28的向量,如下图所示。


那么,对每一个LSTM单元,参数 num_units=128 的话,就是每一个单元的输出为 128*1 的向量,在展开的网络维度来看,如下图所示,对于每一个输入28维的向量,LSTM单元都把它映射到128维的维度, 在下一个LSTM单元时,LSTM会接收上一个128维的输出,和新的28维的输入,处理之后再映射成一个新的128维的向量输出,就这么一直处理下去,知道网络中最后一个LSTM单元,输出一个128维的向量。

 

 


从LSTM的公式的角度看是什么原理呢?我们先看一下LSTM的结构和公式:


参数 num_units=128 的话,

 

  1. 对于公式 (1) ,h=128∗1维, x = 28 ∗ 1 维,[h,x]便等于156 ∗ 1 维,W = 128 ∗ 156维,所以 W ∗ [ h , x ] = 128 ∗ 156 ∗ 156 ∗ 1 = 128 ∗ 1 , b = 128 ∗ 1维, 所以 f = 128 ∗ 1 + 128 ∗ 1 = 128 ∗ 1 维;
  2. 对于公式 (2) 和 (3),同上可分析得 i = 128 ∗ 1维,= 128 ∗ 1维;
  3. 对于公式 (4) ,f ( t ) = 128 ∗ 1 , C ( t − 1 ) = 128 ∗ 1 , f ( t ) . ∗ C ( t − 1 ) = 128 ∗ 1. ∗ 128 ∗ 1 = 128 ∗ 1, 同理可得 C ( t ) = 128 ∗ 1 维;
  4. 对于公式 (5) 和 (6) , 同理可得 O = 128 ∗ 1  维, h = O . ∗ t a n h ( C ) = 128 ∗ 1维。

所以最后LSTM单元输出的h就是 128 ∗ 1 128*1128∗1 的向量。

以上就是 num_units 参数的含义。

如有错误请指出谢谢

参考链接:
https://stackoverflow.com/questions/37901047/what-is-num-units-in-tensorflow-basiclstmcell
https://www.zhihu.com/question/64470274


————————————————
版权声明:本文为CSDN博主「notHeadache」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/notHeadache/article/details/81164264

以上是关于LSTMCell中num_units参数解释的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow 的 LSTMCell 究竟是如何运作的?

Tensorflow中循环神经网络及其Wrappers

关于tensorflow里面的tf.contrib.rnn.BasicLSTMCell 中num_units参数问题

Tensorflow RNN LSTM 输出解释

张量流 BasicLSTMCell 中的 num_units 是啥?

深度学习原理与框架-递归神经网络-RNN网络基本框架(代码?) 1.rnn.LSTMCell(生成单层LSTM) 2.rnn.DropoutWrapper(对rnn进行dropout操作) 3.tf.