Keras 如何处理单元格和隐藏状态(RNN、LSTM)的初始值以进行推理?

Posted

技术标签:

【中文标题】Keras 如何处理单元格和隐藏状态(RNN、LSTM)的初始值以进行推理?【英文标题】:What does Keras do with the initial values of cell & hidden states (RNN, LSTM) for inference? 【发布时间】:2019-07-25 15:58:33 【问题描述】:

假设训练完成:Keras 在推理时(在 LSTM 和 RNN 层中)对第 0 个单元状态和隐藏状态使用什么值?我至少可以想到三种情况,但在文档中找不到任何确凿的答案:

(a) 初始状态被学习,然后用于所有预测

(b) 或者初始状态总是设置为零

(c) 初始状态总是随机的(希望不是……?)

【问题讨论】:

【参考方案1】:

如果使用LSTM(stateful=True),隐藏状态初始化为零,用fitpredict 更改,并保持不变,直到调用.reset_states()。如果LSTM(stateful=False),则在每批拟合/预测/等后重置状态。

这可以从.reset_states()source code验证,并通过直接检查;两者都适用于下面的stateful=True。有关如何传递状态的更多信息,请参阅this answer。


直接检查

batch_shape = (2, 10, 4)
model = make_model(batch_shape)

X = np.random.randn(*batch_shape)
y = np.random.randint(0, 2, (batch_shape[0], 1))

show_lstm_states("STATES INITIALIZED")
model.train_on_batch(X, y)

show_lstm_states("STATES AFTER TRAIN")
model.reset_states()
show_lstm_states("STATES AFTER RESET")

model.predict(X)
show_lstm_states("STATES AFTER PREDICT")

输出

STATES INITIALIZED
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]]
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]]

STATES AFTER TRAIN
[[0.12061571 0.03639204 0.20810013 0.05309075]
 [0.01832913 0.00062357 0.10566339 0.60108346]]
[[0.21241754 0.0773523  0.37392718 0.15590034]
 [0.08496398 0.00112716 0.23814857 0.95995367]]

STATES AFTER RESET
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]]
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]]

STATES AFTER PREDICT
[[0.12162527 0.03720453 0.20628096 0.05421837]
 [0.01849432 0.00064993 0.1045063  0.6097021 ]]
[[0.21398112 0.07894284 0.3709934  0.15928769]
 [0.08605779 0.00117485 0.23606434 0.97212094]]

使用的函数/导入

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Dense, LSTM
from tensorflow.keras.models import Model
import numpy as np

def make_model(batch_shape):
    ipt = Input(batch_shape=batch_shape)
    x   = LSTM(4, stateful=True, activation='relu')(ipt)
    out = Dense(1, activation='sigmoid')(x)

    model = Model(ipt, out)
    model.compile('adam', 'binary_crossentropy')

    return model

def show_lstm_states(txt=''):
    print('\n' + txt) 
    states = model.layers[1].states

    for state in states:
        if tf.__version__[0] == '2':
            print(state.numpy())
        else:
            print(K.get_value(state))

检查源代码

from inspect import getsource
print(getsource(model.layers[1].reset_states))

【讨论】:

【参考方案2】:

我对@9​​87654321@ 的理解是在大多数情况下它们被初始化为零。

【讨论】:

如果只是大多数种情况,那么它是未定义的。

以上是关于Keras 如何处理单元格和隐藏状态(RNN、LSTM)的初始值以进行推理?的主要内容,如果未能解决你的问题,请参考以下文章

如何将 Pandas Dataframe 转换为 Keras RNN 以解决多变量分类问题

Keras 如何处理多标签分类?

pytorch中如何处理RNN输入变长序列padding

Pytorch 中如何处理 RNN 输入变长序列 padding

深度学习实战pytorch中如何处理RNN输入变长序列padding

如何处理keras的单输出多重损失?