TensorFlow:从 RNN 获取所有状态
Posted
技术标签:
【中文标题】TensorFlow:从 RNN 获取所有状态【英文标题】:TensorFlow: getting all states from a RNN 【发布时间】:2017-02-04 14:07:17 【问题描述】:如何从 TensorFlow 中的 tf.nn.rnn()
或 tf.nn.dynamic_rnn()
获取所有隐藏状态? API 只给了我最终的状态。
第一种选择是在构建直接在 RNNCell 上运行的模型时编写一个循环。但是,时间步数对我来说不是固定的,取决于传入的批次。
一些选项是使用 GRU 或编写我自己的 RNNCell 将状态连接到输出。前一种选择不够通用,而后者听起来太老套了。
另一种选择是执行the answers in this question 之类的操作,从 RNN 获取所有变量。但是,我不确定如何在这里以标准方式将隐藏状态与其他变量分开。
有没有一种很好的方法可以在仍然使用库提供的 RNN API 的同时从 RNN 中获取所有隐藏状态?
【问题讨论】:
我创建了一个 PR here,它可能会帮助您处理简单的案例 【参考方案1】:tf.nn.dynamic_rnn(also tf.nn.static_rnn) 有两个返回值; “输出”、“状态” (https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)
正如你所说,“状态”是RNN的最终状态,但“输出”都是RNN的隐藏状态(形状为[batch_size,max_time,cell.output_size])
您可以使用“输出”作为 RNN 的隐藏状态,因为在大多数库提供的 RNNCell 中,“输出”和“状态”是相同的。 (LSTMCell 除外)
基础https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L347 GRU https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L441【讨论】:
撇开这是特定于 GRU 的,如果您有多个层,这对您没有帮助,例如,如果您将 GRUCell 包装在 MultiRNNCell 中。您的输出将仅包含来自最后一层的状态。【参考方案2】:我已经创建了一个 PR here,它可能会帮助您处理简单的案例
让我简要解释一下我的实现,以便您可以根据需要编写自己的版本。主要是_time_step
函数的修改:
def _time_step(time, output_ta_t, state, *args):
参数保持不变,只是传入了一个额外的*args
。但是为什么args
?那是因为我想支持 tensorflow 的习惯行为。您只需忽略args
参数即可返回最终状态:
if states_ta is not None:
# If you want to return all states, set `args` to be `states_ta`
loop_vars = (time, output_ta, state, states_ta)
else:
# If you want the final state only, ignore `args`
loop_vars = (time, output_ta, state)
如何使用?
if args:
args = tuple(
ta.write(time, out) for ta, out in zip(args[0], [new_state])
)
其实这只是对以下(原始)代码的修改:
output_ta_t = tuple(
ta.write(time, out) for ta, out in zip(output_ta_t, output)
)
现在args
应该包含您想要的所有状态。
完成以上所有工作后,您可以使用以下代码获取状态(或最终状态):
_, output_final_ta, *state_info = control_flow_ops.while_loop( ...
和
if states_ta is not None:
final_state, states_final_ta = state_info
else:
final_state, states_final_ta = state_info[0], None
虽然我没有在复杂的情况下测试过它,但它应该在“简单”条件下工作(here's我的测试用例)
【讨论】:
感谢您花时间撰写答案。在回答您的第一句话时,最好不要在 Stack Overflow 上有重复的信息。一旦您拥有 75 名声望,您就可以将一个问题标记为另一个问题的重复(尽管我可能错了,也许您现在可以这样做)。如果问题不相同,最好根据问题的需要调整每个答案。 感谢您的评论!我已经发现这两个问题之间存在一些差异,所以现在我听从了您的建议并定制了每个答案:) 我实际解决这个问题的方法是创建一个包装单元(如 MultiRNNCell),它输出与输出连接的状态。之后只需要进行拆分即可将输出与隐藏状态分开。以上是关于TensorFlow:从 RNN 获取所有状态的主要内容,如果未能解决你的问题,请参考以下文章