TensorFlow 中 RNN 和 Seq2Seq 模型的 API 参考

Posted

技术标签:

【中文标题】TensorFlow 中 RNN 和 Seq2Seq 模型的 API 参考【英文标题】:API Reference for RNN and Seq2Seq models in tensorflow 【发布时间】:2016-09-06 15:53:25 【问题描述】:

我在哪里可以找到指定 RNN 和 Seq2Seq 模型中可用函数的 API 参考。

在 github 页面中提到将 rnn 和 seq2seq 移至 tf.nn

【问题讨论】:

您使用的是 Python 还是 C++ API? 【参考方案1】:

[注意:此答案已针对 r1.0 进行了更新...但解释了 legacy_seq2seq 而不是 tensorflow/tensorflow/contrib/seq2seq/]

好消息是 tensorflow 中提供的 seq2seq 模型非常复杂,包括嵌入、桶、注意力机制、一对多多任务模型等。

坏消息是 Python 代码中有很多复杂性和抽象层,据我所知,代码本身是更高级别 RNN 和 seq2seq “API”的最佳可用“文档” ...谢天谢地,代码是很好的 docstring'd。

实际上,我认为下面指出的示例和辅助函数主要用于参考以了解编码模式……并且在大多数情况下,您需要使用下面的基本函数重新实现所需的功能级Python API

这里是从上到下对 r1.0 版本的 RNN seq2seq 代码的细分:

models/tutorials/rnn/translate/translate.py

...提供main()train()decode(),可以开箱即用地将英语翻译成法语...但是您可以将此代码调整到其他数据集

models/tutorials/rnn/translate/seq2seq_model.py

...class Seq2SeqModel() 设置了一个复杂的 RNN 编码器-解码器,带有嵌入、桶、注意力机制......如果您不需要嵌入、桶或注意力,您将需要实现一个类似的类。

tensorflow/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py

...seq2seq 模型通过辅助函数的主要入口点。见model_with_buckets()embedding_attention_seq2seq()embedding_attention_decoder()attention_decoder()sequence_loss()等。 示例包括one2many_rnn_seq2seq 和没有嵌入/注意的模型也提供了类似basic_rnn_seq2seq。如果您可以将数据塞入这些函数将接受的张量中,这可能是您构建自己的模型的最佳切入点。

tensorflow/tensorflow/contrib/rnn/python/ops/core_rnn.py

...为像 static_rnn() 这样的 RNN 网络提供了一个包装器,带有一些我通常不需要的花里胡哨,所以我只使用这样的代码:

def simple_rnn(cell, inputs, dtype, score):
    with variable_scope.variable_scope(scope or "simple_RNN") as varscope1:
            if varscope1.caching_device is None:
                varscope1.set_caching_device(lambda op: op.device)

        batch_size = array_ops.shape(inputs[0])[0]
        outputs = []
        state = cell.zero_state(batch_size, dtype)            

        for time, input_t in enumerate(inputs):
           if time > 0:      
             variable_scope.get_variable_scope().reuse_variables()


           (output, state) = cell(input_t, state)

           outputs.append(output)

        return outputs, state

【讨论】:

【参考方案2】:

到目前为止,我在他们的网站上也找不到关于 rnn 函数的 API 参考资料。

不过,我相信你可以在github上看到每个函数的cmets作为函数参考。

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn.py

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell.py

【讨论】:

【参考方案3】:

TensorFlow 的当前/主版本的 RNN 文档: https://www.tensorflow.org/versions/master/api_docs/python/nn.html#recurrent-neural-networks

特定版本 TensorFlow 的 RNN 文档: https://www.tensorflow.org/versions/r0.10/api_docs/python/nn.html#recurrent-neural-networks

为了好奇,这里有一些关于为什么最初没有 RNN 文档的注释: API docs does not list RNNs

【讨论】:

以上是关于TensorFlow 中 RNN 和 Seq2Seq 模型的 API 参考的主要内容,如果未能解决你的问题,请参考以下文章

如何用TensorFlow构建RNN

如何用TensorFlow构建RNN

Tensorflow 与 Keras 中的 RNN,tf.nn.dynamic_rnn() 的贬值

TensorFlow框架之RNN循环神经网络详解

学习Tensorflow的LSTM的RNN例子

TensorFlow 2 中的堆叠双向 RNN 令人困惑