如何在keras中包装张量流RNNCell?

Posted

技术标签:

【中文标题】如何在keras中包装张量流RNNCell?【英文标题】:How to wrap a tensorflow RNNCell in keras? 【发布时间】:2019-05-14 01:33:34 【问题描述】:

我想在 keras 层中实现自定义 LSTM 单元。实际上这个实现存在于 tensorflow 中,所以我想知道是否可以将其包装为 keras 层并在模型中调用它。

我发现官方documentation 太简单了,看不到如何构建自定义 RNN 层。 here 和 here 也有类似的问题,但似乎没有得到解决。

提前感谢您的帮助!

【问题讨论】:

【参考方案1】:

现在 tensorflow 的文档可能在问题发布后有所改进。

您可能需要查看this guide 或this SO answer 以供参考。

【讨论】:

【参考方案2】:

根据我的理解,您应该能够在类层的 init() 中初始化单元格,然后在调用方法中使用您的输入引用它。

例如:

class MySimpleLayer(Layer):
  def __init__(self, lstm_size):
    super(MySimpleLayer, self).__init__()
    self.lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)

  def call(self, batch, state):
    return self.lstm(batch, state)

layer = MySimpleLayer(lstm_size)
logits = layer(batch, state)

这个实现是最基本的,所以你可能需要研究 build() 和 compute_output_shape() 方法来处理更复杂的用例。

【讨论】:

对不起,call 这样的定义与Layer 不匹配;我得到 TypeError: call() missing 1 required positional argument: 'states' Call() 肯定与 Layer 类一起使用,正如 here 所指定的那样。对我来说,这看起来像是一个实施错误。尝试将 [batch, state] 作为单个列表输入传递给 call()。

以上是关于如何在keras中包装张量流RNNCell?的主要内容,如果未能解决你的问题,请参考以下文章

如何加载张量流模型

如何在张量流中对张量进行子集化?

如何在张量流中加载本地图像?

如何在张量流中将 TextVectorization 保存到磁盘?

如何在张量流 TakeDataset 上使用 file_paths?

如何测试我在真实图片上训练过的张量流模型?