如何将最后一个输出 y(t-1) 作为输入以在 tensorflow RNN 中生成 y(t)?
Posted
技术标签:
【中文标题】如何将最后一个输出 y(t-1) 作为输入以在 tensorflow RNN 中生成 y(t)?【英文标题】:How can I feed last output y(t-1) as input for generating y(t) in tensorflow RNN? 【发布时间】:2016-05-10 19:06:37 【问题描述】:我想在 Tensorflow 中设计一个单层 RNN,让最后一个输出 (y(t-1))
参与更新隐藏状态。
h(t) = tanh(W_ih * x(t) + W_hh * h(t) + **W_ohy(t - 1)**)
y(t) = W_ho*h(t)
如何将最后一个输入 y(t - 1)
作为更新隐藏状态的输入?
【问题讨论】:
目前,我正在研究这个看起来很有希望的教程:github.com/ematvey/tensorflow-seq2seq-tutorials/blob/master/… 相关帖子:***.com/questions/39681026/… 【参考方案1】:我将您所描述的称为“自回归 RNN”。这是一个(不完整的)代码 sn-p,它显示了如何使用 tf.nn.raw_rnn
创建一个:
import tensorflow as tf
LSTM_SIZE = 128
BATCH_SIZE = 64
HORIZON = 10
lstm_cell = tf.nn.rnn_cell.LSTMCell(LSTM_SIZE, use_peepholes=True)
class RnnLoop:
def __init__(self, initial_state, cell):
self.initial_state = initial_state
self.cell = cell
def __call__(self, time, cell_output, cell_state, loop_state):
emit_output = cell_output # == None for time == 0
if cell_output is None: # time == 0
initial_input = tf.fill([BATCH_SIZE, LSTM_SIZE], 0.0)
next_input = initial_input
next_cell_state = self.initial_state
else:
next_input = cell_output
next_cell_state = cell_state
elements_finished = (time >= HORIZON)
next_loop_state = None
return elements_finished, next_input, next_cell_state, emit_output, next_loop_state
rnn_loop = RnnLoop(initial_state=initial_state_tensor, cell=lstm_cell)
rnn_outputs_tensor_array, _, _ = tf.nn.raw_rnn(lstm_cell, rnn_loop)
rnn_outputs_tensor = rnn_outputs_tensor_array.stack()
这里我们用一些向量initial_state_tensor
初始化LSTM的内部状态,并在t=0
处输入零数组作为输入。之后,当前时间步的输出就是下一个时间步的输入。
【讨论】:
你将如何在 Keras @Alexey Petrenko 中实现这个? 不确定 Keras,但这在 PyTorch 中绝对是微不足道的! :D【参考方案2】:一种可能性是使用我在article 中找到的tf.nn.raw_rnn
。检查我对此related post 的回答。
【讨论】:
【参考方案3】:y(t-1) 是最后一个输入还是输出?在这两种情况下,它都不是直接适合 TensorFlow RNN 单元抽象。如果您的 RNN 很简单,您可以自己编写循环,然后您就可以完全控制。我会使用的另一种方法是预处理您的 RNN 输入,例如,执行以下操作:
processed_input[t] = tf.concat(input[t], input[t-1])
然后用 processes_input 调用 RNN 单元并在那里拆分。
【讨论】:
它自己使用最后生成的输出。如何为简单的 RNN 编写循环?在幕后优化的方式是好的。以上是关于如何将最后一个输出 y(t-1) 作为输入以在 tensorflow RNN 中生成 y(t)?的主要内容,如果未能解决你的问题,请参考以下文章