如何在keras中包装张量流RNNCell?
Posted
技术标签:
【中文标题】如何在keras中包装张量流RNNCell?【英文标题】:How to wrap a tensorflow RNNCell in keras? 【发布时间】:2019-05-14 01:33:34 【问题描述】:我想在 keras 层中实现自定义 LSTM 单元。实际上这个实现存在于 tensorflow 中,所以我想知道是否可以将其包装为 keras 层并在模型中调用它。
我发现官方documentation 太简单了,看不到如何构建自定义 RNN 层。 here 和 here 也有类似的问题,但似乎没有得到解决。
提前感谢您的帮助!
【问题讨论】:
【参考方案1】:现在 tensorflow 的文档可能在问题发布后有所改进。
您可能需要查看this guide 或this SO answer 以供参考。
【讨论】:
【参考方案2】:根据我的理解,您应该能够在类层的 init() 中初始化单元格,然后在调用方法中使用您的输入引用它。
例如:
class MySimpleLayer(Layer):
def __init__(self, lstm_size):
super(MySimpleLayer, self).__init__()
self.lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
def call(self, batch, state):
return self.lstm(batch, state)
layer = MySimpleLayer(lstm_size)
logits = layer(batch, state)
这个实现是最基本的,所以你可能需要研究 build() 和 compute_output_shape() 方法来处理更复杂的用例。
【讨论】:
对不起,call
这样的定义与Layer
不匹配;我得到 TypeError: call() missing 1 required positional argument: 'states'
Call() 肯定与 Layer 类一起使用,正如 here 所指定的那样。对我来说,这看起来像是一个实施错误。尝试将 [batch, state] 作为单个列表输入传递给 call()。以上是关于如何在keras中包装张量流RNNCell?的主要内容,如果未能解决你的问题,请参考以下文章
如何在张量流中将 TextVectorization 保存到磁盘?