手动从 Tensorflow 导入 LSTM 到 PyTorch
Posted
技术标签:
【中文标题】手动从 Tensorflow 导入 LSTM 到 PyTorch【英文标题】:Import LSTM from Tensorflow to PyTorch by hand 【发布时间】:2019-08-09 07:25:51 【问题描述】:我正在尝试将预训练模型从 tensorflow 导入 PyTorch。它接受单个输入并将其映射到单个输出。 当我尝试导入 LSTM 权重时,出现了混乱
我使用以下函数从文件中读取权重及其变量:
def load_tf_model_weights():
modelpath = 'models/model1.ckpt.meta'
with tf.Session() as sess:
tf.train.import_meta_graph(modelpath)
init = tf.global_variables_initializer()
sess.run(init)
vars = tf.trainable_variables()
W = sess.run(vars)
return W,vars
W,V = load_tf_model_weights()
然后我正在检查权重的形状
In [33]: [w.shape for w in W]
Out[33]: [(51, 200), (200,), (100, 200), (200,), (50, 1), (1,)]
此外,变量定义为
In [34]: V
Out[34]:
[<tf.Variable 'rnn/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(51, 200) dtype=float32_ref>,
<tf.Variable 'rnn/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(200,) dtype=float32_ref>,
<tf.Variable 'rnn/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(100, 200) dtype=float32_ref>,
<tf.Variable 'rnn/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(200,) dtype=float32_ref>,
<tf.Variable 'weight:0' shape=(50, 1) dtype=float32_ref>,
<tf.Variable 'FCLayer/Variable:0' shape=(1,) dtype=float32_ref>]
所以我可以说W
的第一个元素定义了 LSTM 的内核,第二个元素定义了它的偏差。根据this post,内核的形状定义为
[input_depth + h_depth, 4 * self._num_units]
偏差为[4 * self._num_units]
。我们已经知道input_depth
是1
。所以我们得到,h_depth
和 _num_units
都有 50
的值。
在 pytorch 中,我想要为其分配权重的 LSTMCell 如下所示:
In [38]: cell = nn.LSTMCell(1,50)
In [39]: [p.shape for p in cell.parameters()]
Out[39]:
[torch.Size([200, 1]),
torch.Size([200, 50]),
torch.Size([200]),
torch.Size([200])]
前两个条目可以被W
的第一个值覆盖,其形状为(51,200)
。但是来自 Tensorflow 的 LSTMCell 只产生一个形状偏差 (200)
而 pytorch 想要其中两个
通过排除偏差,我还剩下了权重:
cell2 = nn.LSTMCell(1,50,bias=False)
[p.shape for p in cell2.parameters()]
Out[43]: [torch.Size([200, 1]), torch.Size([200, 50])]
谢谢!
【问题讨论】:
【参考方案1】:pytorch 使用 CuDNN 的 LSTM 底层(即使你没有 CUDA,它仍然使用兼容的东西)因此它有一个额外的偏置项。
因此,您可以选择两个总和等于 1 的数字(0 和 1、1/2 和 1/2 或其他任何值),并将您的 pytorch 偏差设置为这些数字乘以 TF 的偏差。
pytorch_bias_1 = torch.from_numpy(alpha * tf_bias_data)
pytorch_bias_2 = torch.from_numpy((1.0-alpha) * tf_bias_data)
【讨论】:
以上是关于手动从 Tensorflow 导入 LSTM 到 PyTorch的主要内容,如果未能解决你的问题,请参考以下文章
TensorFlow搭建双向LSTM实现时间序列预测(负荷预测)
TensorFlow搭建LSTM实现时间序列预测(负荷预测)
如何从 pandas 数据帧在 tensorflow v1 中实现 LSTM