Keras 中的 LSTM 实现是如何工作的

Posted

技术标签:

【中文标题】Keras 中的 LSTM 实现是如何工作的【英文标题】:How does the LSTM implementation in Keras work 【发布时间】:2018-07-21 05:44:26 【问题描述】:

我正在查看 recurrent.pyLSTMCell 类的代码 (https://github.com/keras-team/keras/blob/master/keras/layers/recurrent.py)

该类是否计算单个时间步的隐藏和携带状态?

我在哪里可以找到处理展开网络的代码,即从一个时间步到另一个时间步?

我正在尝试计算单个示例在每个时间步的每个门的输出。到目前为止,我可以从经过训练的网络中提取权重、偏差并按照第 1828 行到 1858 行的代码计算激活值。特别是:

i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
                                          self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
                                          self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
                                                self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
                                          self.recurrent_kernel_o))

我的输入有形状:输入(seq_length,nb_dim)。所以要正确计算每个门的输出,我应该这样做:

for step in range(seq_length):
  input_step = input[step, :]
  x_i = np.dot(input_step, kernel_i) + bias_i
  i = recurrent_activation(x_i + np.dot(h_tm1_i, recurrent_kernel_i)
  <<< repeat for other gates >>>
  <<<compute cell hidden state/carry state>>>

【问题讨论】:

【参考方案1】:

我在哪里可以找到处理展开网络的代码,即从一个时间步到另一个时间步?

这个逻辑是由keras.backend.rnn函数(recurrent.py)完成的:

last_output, outputs, states = K.rnn(step,
                                     inputs,
                                     initial_state,
                                     constants=constants,
                                     go_backwards=self.go_backwards,
                                     mask=mask,
                                     unroll=self.unroll,
                                     input_length=timesteps)

step 基本上是一个单元格的调用...

def step(inputs, states):
  return self.cell.call(inputs, states, **kwargs)

... 在 LSTM 单元的情况下,如您的问题中所述,计算 ifco 门,并评估它们的输出和状态张量。

如果您使用的是 tensorflow 后端,您可以在 keras/backend/tensorflow_backend.py 中找到迭代输入序列的实际循环。

【讨论】:

谢谢。正是我想要的。还有一件事,我检查了tensorflow_backend.py (github.com/keras-team/keras/blob/master/keras/backend/…)。与展开的 rnn(预测时)相关的代码是从第 2662 行到第 2670 行。对吗?您知道哪个导入的文件包含函数step_function() 的代码吗?谢谢

以上是关于Keras 中的 LSTM 实现是如何工作的的主要内容,如果未能解决你的问题,请参考以下文章

Keras 中的 CuDNNLSTM 和 LSTM 有啥区别?

如何重塑文本数据以适合 keras 中的 LSTM 模型

如何在 Keras 中解释 LSTM 层中的权重 [关闭]

Keras 中的多对一和多对多 LSTM 示例

Keras中的LSTM

Keras 的 LSTM 中的时间步长是多少?