Libtorch:无法加载跟踪的 lstm 脚本模型

Posted

技术标签:

【中文标题】Libtorch:无法加载跟踪的 lstm 脚本模型【英文标题】:Libtorch: cannot load traced lstm scriptmodel 【发布时间】:2019-08-16 13:23:24 【问题描述】:

我保存了一个 pytorch ScriptModule 并使用 libtorch 加载它。但是我遇到了以下问题

我用的是win10下的linux子系统,我用的是pytorch 1.2。

要重现我的问题,您可以运行这段 python 代码来保存一个 pt 模型

import torch
import torch.nn as nn


# TODO: https://github.com/pytorch/pytorch/issues/23930
class test(torch.jit.ScriptModule):

    def __init__(self, vocab_size=10, rnn_dims=512):
        super().__init__()
        self.word_embeds = nn.Embedding(vocab_size, rnn_dims)
        self.emb_drop = nn.Dropout(0.1)
        self.rnn = nn.LSTM(input_size=rnn_dims, hidden_size=rnn_dims, batch_first=True,
                           num_layers=2, dropout=0.1)
        # delattr(self.rnn, 'forward_packed')

    @torch.jit.script_method
    def forward(self, x):
        h1 = (torch.zeros(2, 1, 512), torch.zeros(2, 1, 512))
        embeds = self.emb_drop(self.word_embeds(x))
        out, h1 = self.rnn(embeds, h1)

        return h1


model = test()

input = torch.ones((1,3)).long()
output = model(input)
print('output', output)

# torch.onnx.export(model,  # model being run
#                   input,
#                   'test.onnx',
#                   example_outputs=output)
#torch.jit.trace(model, (torch.ones((1,3)).long(), torch.ones((3,1))), check_trace=False)
model.save('lstm_test.pt')

然后在 libtorch 中加载模型。

我不知道为什么会出现此错误。我根本不使用 PackedSequence。希望有人能帮帮我。

【问题讨论】:

【参考方案1】:

我现在知道出了什么问题。 libtorch 版本是官网的错误版本。现在我使用正确的 libtorch 1.2 就可以了。参考issuehttps://github.com/pytorch/pytorch/issues/24382

【讨论】:

以上是关于Libtorch:无法加载跟踪的 lstm 脚本模型的主要内容,如果未能解决你的问题,请参考以下文章

libtorch 模型加密

libtorch (pytorch c++) 教程

libtorch (pytorch c++) 教程

libtorch(pytorch c++)教程

libtorch(pytorch c++)教程

在动作脚本 3.0 中使用共享对象加载显示对象数组