How can I use a stateful LSTM model to predict without specifying the same batch_size as I trained it?
Posted: 2020-03-07 00:19:49

Question: I tried setting `stateful=True` to train my LSTM model, and it worked.
However, a stateful RNN requires me to reshape my input to the same `batch_size` that I set for the first layer. Otherwise I get the error `InvalidArgumentError: Invalid input_h shape`.
I set `batch_size` to 64, but I only want to feed in a single starting sentence to generate text. If I had to provide input with `batch_size=64`, I would need to prepare 64 sentences, which is absurd.
Without `stateful=True` everything works fine, but I need the performance improvement. In this case, how can I use a stateful LSTM model without matching the `batch_size` I set?
My model definition:

```python
from tensorflow import keras

seq_length = 100
batch_size = 64
epochs = 3
vocab_size = len(vocab)  # 65 (`vocab` is defined elsewhere in my code)
embedding_dim = 256
rnn_units = 1024

def bi_lstm(vocab_size, embedding_dim, batch_size, rnn_units):
    model = keras.models.Sequential([
        # batch_input_shape fixes the batch dimension, which stateful=True requires
        keras.layers.Embedding(vocab_size, embedding_dim,
                               batch_input_shape=(batch_size, None)),
        keras.layers.Bidirectional(
            keras.layers.LSTM(units=rnn_units,
                              return_sequences=True,
                              stateful=True,
                              recurrent_initializer="glorot_uniform")),
        keras.layers.Dense(vocab_size),
    ])
    return model
```
I ran a simple test like this, and it raised the error:

```python
for x, y in seq_dataset.take(1):
    x = x[:-10, :]  # change the batch size from 64 to 54; it works fine if I delete this line
    print(x.shape)
    pred = model(x)
    print(pred.shape)
```
InvalidArgumentError Traceback (most recent call last)
<ipython-input-98-99323ee3e09d> in <module>()
2 x = x[:-10,:]
3 print(x.shape)
----> 4 pred = model(x)
5 print(pred.shape)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
889 with base_layer_utils.autocast_context_manager(
890 self._compute_dtype):
--> 891 outputs = self.call(cast_inputs, *args, **kwargs)
892 self._handle_activity_regularization(inputs, outputs)
893 self._set_mask_metadata(inputs, outputs, input_masks)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/sequential.py in call(self, inputs, training, mask)
254 if not self.built:
255 self._init_graph_network(self.inputs, self.outputs, name=self.name)
--> 256 return super(Sequential, self).call(inputs, training=training, mask=mask)
257
258 outputs = inputs # handle the corner case where self.layers is empty
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/network.py in call(self, inputs, training, mask)
706 return self._run_internal_graph(
707 inputs, training=training, mask=mask,
--> 708 convert_kwargs_to_constants=base_layer_utils.call_context().saving)
709
710 def compute_output_shape(self, input_shape):
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/network.py in _run_internal_graph(self, inputs, training, mask, convert_kwargs_to_constants)
858
859 # Compute outputs.
--> 860 output_tensors = layer(computed_tensors, **kwargs)
861
862 # Update tensor_dict.
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/layers/wrappers.py in __call__(self, inputs, initial_state, constants, **kwargs)
526
527 if initial_state is None and constants is None:
--> 528 return super(Bidirectional, self).__call__(inputs, **kwargs)
529
530 # Applies the same workaround as in `RNN.__call__`
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
889 with base_layer_utils.autocast_context_manager(
890 self._compute_dtype):
--> 891 outputs = self.call(cast_inputs, *args, **kwargs)
892 self._handle_activity_regularization(inputs, outputs)
893 self._set_mask_metadata(inputs, outputs, input_masks)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/layers/wrappers.py in call(self, inputs, training, mask, initial_state, constants)
640
641 y = self.forward_layer(forward_inputs,
--> 642 initial_state=forward_state, **kwargs)
643 y_rev = self.backward_layer(backward_inputs,
644 initial_state=backward_state, **kwargs)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/layers/recurrent.py in __call__(self, inputs, initial_state, constants, **kwargs)
621
622 if initial_state is None and constants is None:
--> 623 return super(RNN, self).__call__(inputs, **kwargs)
624
625 # If any of `initial_state` or `constants` are specified and are Keras
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
889 with base_layer_utils.autocast_context_manager(
890 self._compute_dtype):
--> 891 outputs = self.call(cast_inputs, *args, **kwargs)
892 self._handle_activity_regularization(inputs, outputs)
893 self._set_mask_metadata(inputs, outputs, input_masks)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/layers/recurrent_v2.py in call(self, inputs, mask, training, initial_state)
959 if can_use_gpu:
960 last_output, outputs, new_h, new_c, runtime = cudnn_lstm(
--> 961 **cudnn_lstm_kwargs)
962 else:
963 last_output, outputs, new_h, new_c, runtime = standard_lstm(
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/layers/recurrent_v2.py in cudnn_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask, time_major, go_backwards)
1172 outputs, h, c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
1173 inputs, input_h=init_h, input_c=init_c, params=params, is_training=True,
-> 1174 rnn_mode='lstm')
1175
1176 last_output = outputs[-1]
/tensorflow-2.0.0/python3.6/tensorflow_core/python/ops/gen_cudnn_rnn_ops.py in cudnn_rnn(input, input_h, input_c, params, rnn_mode, input_mode, direction, dropout, seed, seed2, is_training, name)
107 input_mode=input_mode, direction=direction, dropout=dropout,
108 seed=seed, seed2=seed2, is_training=is_training, name=name,
--> 109 ctx=_ctx)
110 except _core._SymbolicException:
111 pass # Add nodes to the TensorFlow graph.
/tensorflow-2.0.0/python3.6/tensorflow_core/python/ops/gen_cudnn_rnn_ops.py in cudnn_rnn_eager_fallback(input, input_h, input_c, params, rnn_mode, input_mode, direction, dropout, seed, seed2, is_training, name, ctx)
196 "is_training", is_training)
197 _result = _execute.execute(b"CudnnRNN", 4, inputs=_inputs_flat,
--> 198 attrs=_attrs, ctx=_ctx, name=name)
199 _execute.record_gradient(
200 "CudnnRNN", _inputs_flat, _attrs, _result, name)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
65 else:
66 message = e.message
---> 67 six.raise_from(core._status_to_exception(e.code, message), None)
68 except TypeError as e:
69 keras_symbolic_tensors = [
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: Invalid input_h shape: [1,64,1024] [1,54,1024] [Op:CudnnRNN]
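One way to read the last line of the traceback: a stateful layer keeps its hidden state between calls in an array whose batch dimension was fixed when the model was built (64 here). The leading 1 appears to be cuDNN's layers-times-directions axis for one direction of the `Bidirectional` wrapper. A batch of 54 sequences cannot be combined with a state that has 64 rows. The NumPy snippet below is only an analogy for that shape check, not what Keras or cuDNN actually execute, and the variable names are made up for illustration.

```python
import numpy as np

units = 1024
stored_h = np.zeros((64, units))    # state kept by the layer, sized for batch_size=64
x_proj = np.zeros((54, units))      # input contribution for a batch of 54 sequences

mismatch = None
try:
    # A recurrent update combines the input term with the stored state row by row.
    new_h = np.tanh(x_proj + stored_h)
except ValueError as err:
    mismatch = str(err)             # NumPy refuses to broadcast (54, 1024) with (64, 1024)

print(mismatch)
```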
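Answer 1: When `stateful=True`, the model genuinely needs `batch_size` for its logic to work.

However, your model's weights do not need to know `batch_size` at all. It would be convenient to have a `set_batch_size()` method, or better still, for `fit()` and `predict()` to infer it from the input. Unfortunately, Keras does neither.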
There is a workaround, though: define another instance of the model with `batch_size=1` (or any number you like). Then copy the trained model's weights into this new model that has a different batch size:

```python
model64 = bi_lstm(vocab_size, embedding_dim, batch_size=64, rnn_units=rnn_units)
model64.fit(...)
# optional: model64.save_weights('model64_weights.hdf5')

model1 = bi_lstm(vocab_size, embedding_dim, batch_size=1, rnn_units=rnn_units)
model1.set_weights(model64.get_weights())  # or: model1.load_weights('model64_weights.hdf5')
model1.predict(...)
```
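To make the workaround concrete without TensorFlow, here is a toy NumPy sketch under stated assumptions. `ToyStatefulRNN` is a hypothetical stand-in, not a Keras class; its `get_weights`/`set_weights`/`reset_states` methods only mimic the idea behind the Keras methods of the same names. Its trainable arrays have the same shapes whatever `batch_size` is, so they can be copied from a batch-64 instance to a batch-1 instance. That instance can then be primed with one starting sequence and advanced one step at a time while it carries its state between calls.

```python
import numpy as np

class ToyStatefulRNN:
    """Hypothetical stand-in for a stateful recurrent layer (not the Keras API)."""

    def __init__(self, input_dim, units, batch_size, seed=0):
        rng = np.random.default_rng(seed)
        # Trainable weights: shapes depend on input_dim and units only.
        self.W = rng.standard_normal((input_dim, units)) * 0.1
        self.U = rng.standard_normal((units, units)) * 0.1
        self.b = np.zeros(units)
        # Persistent state: the only array whose shape involves batch_size.
        self.h = np.zeros((batch_size, units))

    def get_weights(self):
        return [self.W.copy(), self.U.copy(), self.b.copy()]

    def set_weights(self, weights):
        self.W, self.U, self.b = (w.copy() for w in weights)

    def reset_states(self):
        self.h = np.zeros_like(self.h)

    def __call__(self, x):
        # x has shape (batch, timesteps, input_dim); the state carries over between calls.
        if x.shape[0] != self.h.shape[0]:
            raise ValueError(f"state batch {self.h.shape[0]} != input batch {x.shape[0]}")
        for t in range(x.shape[1]):
            self.h = np.tanh(x[:, t, :] @ self.W + self.h @ self.U + self.b)
        return self.h


rnn64 = ToyStatefulRNN(input_dim=8, units=16, batch_size=64)  # stands in for the trained model
rnn1 = ToyStatefulRNN(input_dim=8, units=16, batch_size=1, seed=1)
rnn1.set_weights(rnn64.get_weights())  # identical shapes, so the copy just works

rnn1.reset_states()                    # start from a clean state for a new prime
h = rnn1(np.ones((1, 10, 8)))          # one "starting sentence" of 10 steps
for _ in range(5):                     # then continue one step at a time
    h = rnn1(np.ones((1, 1, 8)))
print(h.shape)                         # (1, 16)
```

Calling `rnn64` with a single sequence would raise the toy `ValueError`, which mirrors the batch-size mismatch in the question.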
This works because `batch_size` plays no part in the shapes of the weights, so the weights are interchangeable between the two models.
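The claim can be checked by hand for the question's `bi_lstm` model. The sketch below lists the weight shapes used by Keras' `Embedding`, `LSTM` (a `kernel`, `recurrent_kernel` and `bias`, each with the four gates stacked) and `Dense` layers. It assumes the defaults `use_bias=True` and `merge_mode='concat'` for `Bidirectional`. The helper names are hypothetical, and the sketch does not rely on the exact order that `get_weights()` returns. Only the recurrent state shapes involve `batch_size`.

```python
from math import prod

def bi_lstm_weight_shapes(vocab_size, embedding_dim, rnn_units):
    """Weight shapes of the question's bi_lstm model; batch_size is not an input."""
    shapes = [(vocab_size, embedding_dim)]                  # Embedding matrix
    for _direction in ("forward", "backward"):
        shapes += [
            (embedding_dim, 4 * rnn_units),                 # LSTM kernel (4 gates stacked)
            (rnn_units, 4 * rnn_units),                     # LSTM recurrent_kernel
            (4 * rnn_units,),                               # LSTM bias
        ]
    shapes += [(2 * rnn_units, vocab_size), (vocab_size,)]  # Dense on concatenated outputs
    return shapes


def bi_lstm_state_shapes(batch_size, rnn_units):
    """Stateful h and c for each direction: the only place batch_size appears."""
    return [(batch_size, rnn_units)] * 4


weights = bi_lstm_weight_shapes(vocab_size=65, embedding_dim=256, rnn_units=1024)
print(sum(prod(s) for s in weights))                        # 10643777 trainable parameters
print(bi_lstm_state_shapes(64, 1024)[0], bi_lstm_state_shapes(1, 1024)[0])
```

The weight list is identical for `batch_size=64` and `batch_size=1`, while the state shapes change from `(64, 1024)` to `(1, 1024)`. That difference is exactly what the `Invalid input_h shape` error complains about.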
Comments:
Thanks a lot! It really solved my problem. Very helpful advice! Huge help!! tnx :)