state_is_tuple=True 时如何设置 TensorFlow RNN 状态?

Posted

技术标签:

【中文标题】state_is_tuple=True 时如何设置 TensorFlow RNN 状态?【英文标题】:How do I set TensorFlow RNN state when state_is_tuple=True? 【发布时间】:2016-12-30 22:59:40 【问题描述】:

我写了一个RNN language model using TensorFlow。该模型被实现为RNN 类。图结构内置在构造函数中,而RNN.trainRNN.test 方法运行它。

当我移动到训练集中的新文档时,或者当我想在训练期间运行验证集时,我希望能够重置 RNN 状态。我通过管理训练循环中的状态,通过提要字典将其传递到图中来做到这一点。

在构造函数中我这样定义 RNN

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
                                                  initial_state=self.state)

训练循环如下所示

 for document in document:
     state = session.run(self.reset_state)
     for x, y in document:
          _, state = session.run([self.train_step, self.next_state], 
                                 feed_dict=self.x:x, self.y:y, self.state:state)

xy 是文档中的一批训练数据。这个想法是我在每批之后传递最新状态,除非我开始一个新文档,当我通过运行 self.reset_state 将状态归零时。

这一切都有效。现在我想更改我的 RNN 以使用推荐的 state_is_tuple=True。但是,我不知道如何通过提要字典传递更复杂的 LSTM 状态对象。另外我不知道在构造函数中将哪些参数传递给self.state = tf.placeholder(...) 行。

这里的正确策略是什么?仍然没有太多dynamic_rnn 的示例代码或文档可用。


TensorFlow 问题 2695 和 2838 似乎相关。

WILDML 上的blog post 解决了这些问题,但没有直接说明答案。

另见TensorFlow: Remember LSTM state for next batch (stateful LSTM)。

【问题讨论】:

查看rnn_cell._unpacked_staternn_cell._packed_state。这些用于rnn._dynamic_rnn_loop() 将状态作为参数张量列表传递给循环函数。 我在最新的 TensorFlow 源代码中没有看到字符串 _unpacked_state_packed_state。这些名字变了吗? 嗯。那些已被删除。取而代之的是,引入了一个新模块 tf.python.util.nest 与类似物 flattenpack_sequence_as 有没有人尝试更新他们的 TF1.0.1 代码? API 发生了显着变化。 【参考方案1】:

Tensorflow 占位符的一个问题是您只能使用 Python 列表或 Numpy 数组来提供它(我认为)。所以你不能在 LSTMStateTuple 的元组中保存运行之间的状态。

我通过将状态保存在这样的张量中解决了这个问题

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

LSTM 层中有两个组件,cell statehidden state,这就是“2”的来源。 (这篇文章很棒:https://arxiv.org/pdf/1506.00019.pdf)

在构建图形时,您解压缩并创建元组状态,如下所示:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
         [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
          for idx in range(num_layers)]
)

然后你以通常的方式获得新状态

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)

不应该是这样的......也许他们正在研究解决方案。

【讨论】:

如果只有一层,会变成state_placeholder = tf.placeholder(tf.float32, [2, batch_size, state_size])initial_state = np.zeros((2, batch_size, state_size))吗?【参考方案2】:

输入 RNN 状态的一种简单方法是简单地分别输入状态元组的两个组件。

# Constructing the graph
self.state = rnn_cell.zero_state(...)
self.output, self.next_state = tf.nn.dynamic_rnn(
    rnn_cell,
    self.input,
    initial_state=self.state)

# Running with initial state
output, state = sess.run([self.output, self.next_state], feed_dict=
    self.input: input
)

# Running with subsequent state:
output, state = sess.run([self.output, self.next_state], feed_dict=
    self.input: input,
    self.state[0]: state[0],
    self.state[1]: state[1]
)

【讨论】:

以上是关于state_is_tuple=True 时如何设置 TensorFlow RNN 状态?的主要内容,如果未能解决你的问题,请参考以下文章

deep_learning_Function_rnn_cell.BasicLSTMCell

tf.nn.rnn_cell.MultiRNNCell

当 setAcceptDrops 设置设置为 True 时,如何检测 QListWidet 内部移动信号

使用 gsprint 打印时如何设置 collat​​e = true

不使用 FormsAuthentication.RedirectFromLoginPage 时如何将 Request.IsAuthenticated 设置为 true?

设置为null后如何将onItemClickListener重置为true?