使用 Keras,如何将 CuDNNLSTM 生成的权重加载到 LSTM 模型中?
Posted
技术标签:
【中文标题】使用 Keras,如何将 CuDNNLSTM 生成的权重加载到 LSTM 模型中?【英文标题】:Using Keras, How can I load weights generated from CuDNNLSTM into LSTM Model? 【发布时间】:2018-04-10 17:37:20 【问题描述】:我使用 Keras 开发了一个基于 LSTM 层的 NN 模型。为了提高 Paperspace(GPU 云处理基础设施)的速度,我将 LSTM 层换成了新的 CuDNNLSTM 层。然而,这只适用于支持 GPU cuDNN 的机器。 PS:CuDNNLSTM 仅在 Keras master
上可用,在最新版本中不可用。
所以我已经生成了权重并将它们保存为云上的hdf5
格式,我想在我的 MacBook 上本地使用它们。由于 CuDNNLSTM 层不可用,我只在本地安装时切换回 LSTM。
阅读此tweet about CuDNN from @fchollet 我认为它可以正常工作,只需将权重读回 LSTM 模型即可。
但是,当我尝试导入它们时,Keras 会抛出此错误:
Traceback (most recent call last):
...
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimension 0 in both shapes must be equal, but are 2048 and 4096 for 'Assign_2' (op: 'Assign') with input shapes: [2048], [4096].
...
ValueError: Dimension 0 in both shapes must be equal, but are 2048 and 4096 for 'Assign_2' (op: 'Assign') with input shapes: [2048], [4096]
用 h5cat 分析hdf5
文件我可以看到这两个结构是不同的。
TL;DR
我无法将 CuDNNLSTM 生成的权重加载到 LSTM 模型中。 我做错了什么吗?我怎样才能让它们无缝地工作?
这是我的模型:
SelectedLSTM = CuDNNLSTM if is_gpu_enabled() else LSTM
# ...
model = Sequential()
model.add(SelectedLSTM(HIDDEN_DIM, return_sequences=True, input_shape=(SEQ_LENGTH, vocab_size)))
model.add(Dropout(0.2))
model.add(SelectedLSTM(HIDDEN_DIM, return_sequences=False))
model.add(Dense(vocab_size))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
【问题讨论】:
【参考方案1】:原因是CuDNNLSTM
层的bias
是LSTM
的两倍大。这是因为 cuDNN API 的底层实现。您可以将以下方程(从 cuDNN 用户指南复制)与通常的 LSTM 方程进行比较:
CuDNN 使用两个偏置项,因此偏置权重的数量增加了一倍。要将其转换回 LSTM
使用的值,需要将两个偏差项相加。
我已经提交了一个PR 来进行转换并且它被合并了。您可以从 GitHub 安装最新的 Keras,应该可以解决权重加载问题。
【讨论】:
【参考方案2】:只是添加到上面@Yu-Yang 的答案,最新的Keras 会自动将CuDMMLSTM
权重转换为LSTM
,但它不会为您改变您的.json 模型架构.
要在 LSTM 上运行推理,您需要打开 JSON 文件,并将 CuDNNLSTM
的所有实例手动更改为 LSTM
。然后运行model_from_json
加载您的模型,并运行load_weights
加载您的权重。
我一开始尝试在不手动更改CuDNNLSTM
模型的情况下运行load_weights
,但出现了一堆错误。
【讨论】:
以上是关于使用 Keras,如何将 CuDNNLSTM 生成的权重加载到 LSTM 模型中?的主要内容,如果未能解决你的问题,请参考以下文章
tensorflow.keras.layers:ImportError:无法导入名称“CuDNNLSTM”
Keras 中的 CuDNNLSTM 和 LSTM 有啥区别?
CuDNNLSTM:调用 ThenRnnForward 失败