Tensorflow GRU layer call() arguments -- TypeError: call() got an unexpected keyword argument 'reset_after'
Posted: 2021-11-15 21:27:05

Question:
I implemented a model with a GRU layer. The model and its training work fine:
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, rnn_units):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(rnn_units,
                                       return_sequences=True,
                                       return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        else:
            return x
I only changed the GRU layer definition to (1) make it CuDNN-compatible and (2) add dropout.
In the model definition I kept:
self.gru = tf.keras.layers.GRU(rnn_units,
                               return_sequences=True,
                               return_state=True)
and in the call() method I set:
if states is None:
    states = self.gru.get_initial_state(x)
x, states = self.gru(x, initial_state=states, training=training,
                     reset_after=True, recurrent_activation='sigmoid',  # to make it more GPU friendly
                     recurrent_dropout=0.2, dropout=0.2  # to add some dropout to it
                     )
This seems to follow the Keras/TensorFlow guidelines, but I get this error:
Traceback (most recent call last):
  File "rnn_train_004.py", line 125, in <module>
    example_batch_predictions = model(input_example_batch)
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 1037, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "rnn_train_004.py", line 107, in call
    recurrent_dropout=0.2, dropout=0.2 # to add some dropout to it
  File "/usr/local/lib/python3.6/dist-packages/keras/layers/recurrent.py", line 716, in __call__
    return super(RNN, self).__call__(inputs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 1037, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
TypeError: call() got an unexpected keyword argument 'reset_after'
Comments on the question:
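The arguments reset_after, recurrent_activation, recurrent_dropout and dropout configure the layer, so they must be passed to the constructor. You are passing them to call(), which only accepts per-call arguments such as initial_state and training.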
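The comment's point can be illustrated without TensorFlow. The sketch below uses a hypothetical class, `ToyRecurrentLayer`, that only mimics the split Keras uses: hyperparameters such as `dropout` or `reset_after` are fixed when the layer is constructed, while `call()` accepts only per-invocation arguments such as `initial_state` and `training`. Passing a constructor-only argument to `call()` therefore raises the same kind of `TypeError` as in the traceback.

```python
class ToyRecurrentLayer:
    """Hypothetical stand-in for a Keras RNN layer (not the real GRU)."""

    def __init__(self, units, dropout=0.0, recurrent_dropout=0.0,
                 reset_after=True, recurrent_activation="sigmoid"):
        # Configuration is stored once, at construction time.
        self.units = units
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        self.reset_after = reset_after
        self.recurrent_activation = recurrent_activation

    def call(self, inputs, initial_state=None, training=False):
        # Only per-call arguments are accepted; configuration is read
        # from the attributes set in __init__. Dropout is active only
        # when training=True.
        active_dropout = self.dropout if training else 0.0
        return {"inputs": inputs, "initial_state": initial_state,
                "dropout_applied": active_dropout}


def try_call(layer, **kwargs):
    """Call the layer; return the TypeError message, or None on success."""
    try:
        layer.call([1, 2, 3], **kwargs)
    except TypeError as err:
        return str(err)
    return None


layer = ToyRecurrentLayer(8, dropout=0.2, reset_after=True)
print(try_call(layer, training=True))     # None: training is a call argument
print(try_call(layer, reset_after=True))  # a TypeError message, as in the traceback
```

The exact wording of the message depends on the Python version (newer versions prefix the class name), but it always names the unexpected keyword.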
Answer 1:
Pass these arguments to the GRU constructor instead of the call() method:
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, rnn_units):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        # Layer hyperparameters belong in the constructor, not in call().
        self.gru = tf.keras.layers.GRU(rnn_units,
                                       return_sequences=True,
                                       return_state=True,
                                       reset_after=True,               # required by the cuDNN kernel (TF 2.x default)
                                       recurrent_activation='sigmoid', # required by the cuDNN kernel (TF 2.x default)
                                       recurrent_dropout=0.2,          # note: nonzero recurrent_dropout disables the cuDNN kernel
                                       dropout=0.2)                    # dropout on the layer inputs
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        # Only per-call arguments (initial_state, training) are passed here.
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        else:
            return x
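One caveat relates to the question's goal of a CuDNN-compatible layer. According to the `tf.keras.layers.GRU` documentation (TF 2.x), the fast cuDNN kernel is used only when `activation == tanh`, `recurrent_activation == sigmoid`, `recurrent_dropout == 0`, `unroll` is `False`, `use_bias` is `True` and `reset_after` is `True`. Inputs must also be unmasked or strictly right-padded, and eager execution must be enabled in the outermost context. The `dropout` argument, which acts on the layer inputs, is not in that list, but `recurrent_dropout=0.2` as written above would make the layer fall back to the generic implementation. The helper `cudnn_blockers` below is a hypothetical sketch that checks only the config-level conditions on a dict shaped like the output of `GRU.get_config()`. It does not check masking or execution mode.

```python
# Config-level requirements for the cuDNN GRU kernel, as listed in the
# tf.keras.layers.GRU documentation (TF 2.x). Masking and eager-execution
# requirements are not checked by this sketch.
CUDNN_GRU_REQUIREMENTS = {
    "activation": "tanh",
    "recurrent_activation": "sigmoid",
    "recurrent_dropout": 0,
    "unroll": False,
    "use_bias": True,
    "reset_after": True,
}


def cudnn_blockers(config):
    """Return the config keys whose values would prevent the cuDNN kernel.

    `config` is assumed to look like GRU.get_config() output. Missing keys
    are treated as the required value, since the TF 2.x GRU defaults
    happen to match these requirements.
    """
    blockers = []
    for key, required in CUDNN_GRU_REQUIREMENTS.items():
        if config.get(key, required) != required:
            blockers.append(key)
    return blockers


# Settings from the answer above: only recurrent_dropout blocks cuDNN.
answer_config = {"recurrent_activation": "sigmoid", "reset_after": True,
                 "dropout": 0.2, "recurrent_dropout": 0.2}
print(cudnn_blockers(answer_config))  # ['recurrent_dropout']
```

With TensorFlow installed, one could presumably pass `self.gru.get_config()` instead of a hand-written dict, assuming activations are serialized there as their string names. That path is not exercised in this sketch. If cuDNN speed matters more than recurrent regularization, keeping `recurrent_dropout=0` and relying on `dropout` alone satisfies the listed conditions.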