如何使用 Tensorflow v1.1 seq2seq.dynamic_decode?

Posted

技术标签:

【中文标题】如何使用 Tensorflow v1.1 seq2seq.dynamic_decode?【英文标题】:How to use Tensorflow v1.1 seq2seq.dynamic_decode? 【发布时间】:2017-11-12 23:15:19 【问题描述】:

我正在尝试使用来自 Tensorflow 的 seq2seq.dynamic_decode 来构建序列到序列模型。我已经完成了编码器部分。 我对解码器感到困惑,因为decoder_outputs 似乎返回了[batch_size x sequence_length x embedding_size],但我需要实际的单词索引来正确计算我的损失[batch_size x sequence_length]。 我想知道我的某个形状输入是否不正确,或者我只是忘记了什么。 解码器和编码器单元为rnn.BasicLSTMCell()

# Variables
cell_size = 100
decoder_vocabulary_size = 7
batch_size = 2
decoder_max_sentence_len = 7
# Part of the encoder
_, encoder_state = tf.nn.dynamic_rnn(
          cell=encoder_cell,
          inputs=features,
          sequence_length=encoder_sequence_lengths,
          dtype=tf.float32)
# ---- END Encoder ---- #
# ---- Decoder ---- #
# decoder_sequence_lengths = _sequence_length(features)
embedding = tf.get_variable(
     "decoder_embedding", [decoder_vocabulary_size, cell_size])
helper = seq2seq.GreedyEmbeddingHelper(
     embedding=embedding,
     start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
     end_token=END_SYMBOL)
decoder = seq2seq.BasicDecoder(
     cell=decoder_cell,
     helper=helper,
     initial_state=encoder_state)
decoder_outputs, _ = seq2seq.dynamic_decode(
     decoder=decoder,
     output_time_major=False,
     impute_finished=True,
     maximum_iterations=self.decoder_max_sentence_len)
# I need labels (decoder_outputs) to be indices
losses = nn_ops.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
loss = tf.reduce_mean(losses)

【问题讨论】:

【参考方案1】:

我发现解决办法是:

from tensorflow.python.layers.core import Dense
decoder = seq2seq.BasicDecoder(
      cell=decoder_cell,
      helper=helper,
      initial_state=encoder_state,
      output_layer=Dense(decoder_vocabulary_size))
...
logits = decoder_outputs[0]

您必须指定一个密集层以从 cell_size 投影到词汇大小。

【讨论】:

以上是关于如何使用 Tensorflow v1.1 seq2seq.dynamic_decode?的主要内容,如果未能解决你的问题,请参考以下文章

如何在没有嵌入的情况下使用 tensorflow seq2seq?

使用 seq2seq API(ver 1.1 及更高版本)的 Tensorflow 序列到序列模型

使用Tensorflow搭建一个简单的Seq2Seq翻译模型

Tensorflow动态seq2seq使用总结(r1.3)

TensorFlow 中 RNN 和 Seq2Seq 模型的 API 参考

如何将单词映射到数字以输入到 Tensorflow 神经网络