LSTM参数详解(其余RNN类似)

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了LSTM参数详解(其余RNN类似)相关的知识,希望对你有一定的参考价值。

参考技术A

输入数据 input : (seq_len, batch_size, input_size)
LSTM (input_size, hidden_size, num_layers = 1, bidirectional = False)
其中在时间步 t 的hidden_state ht 和cell_state ct 的shape均为
(num_layers * num_direction, batch_size,hidden_size)
输出向量 output : (seq_leng, batch_size, num_directions * hidden_size)

调用方法 output,(hn,cn) = lstm(input,(h0,c0)) #h0,c0如果省略即为0向量

output.size() == (seq_len, batch_size, hidden_size)
hn.size() == (1, batch_size, hidden_size)
hn 就是 output 的seq_len维度最后一个index的元素。

output.size() == (seq_len, batch_size, 2 * hidden_size)
hn.size() == (2, batch_size, hidden_size)
那么这时 output 在seq_len维度最后一个index的元素其实就是 hn[0] hn[1] 的concatenation
其中 hn[0] LSTM 从左向右编码句子的最后一个hidden_state,对应最后一个token;
然而 hn[1] LSTM 从右向左编码句子的最后一个hidden_state,对应第一个token。
output 如果按照第三维度均分为两份就可以得到 output_forward output_backward
他们的size() 都 == (seq_len, batch_size, hidden_size)
其中
output_forward[-1] == hn[0] 也就是从左向右编码对应最后一个token的hidden_state ;
output_backward[0] == hn[1] 也就是从右向左编码对应第一个token的hidden_state ;

以上是关于LSTM参数详解(其余RNN类似)的主要内容,如果未能解决你的问题,请参考以下文章

(数据科学学习手札39)RNN与LSTM基础内容详解

[Pytorch系列-53]:循环神经网络 - torch.nn.LSTM()参数详解

pytorch nn.LSTM()参数详解

[Pytorch系列-51]:循环神经网络RNN - torch.nn.RNN类的参数详解与代码示例

TensorFlow 中 LSTM-RNN 参数的占位符

什么是使用Keras的RNN Layer的return_state输出