保存和加载 Pytorch 模型检查点以进行推理不起作用

Posted

技术标签:

【中文标题】保存和加载 Pytorch 模型检查点以进行推理不起作用【英文标题】:Saving and Loading Pytorch Model Checkpoint for inference not working 【发布时间】:2019-06-13 04:27:37 【问题描述】:

我有一个使用 LSTM 训练的模型。该模型是在 GPU 上训练的(在 Google COLABORATORY 上)。 我必须保存模型以进行推理;我将在 CPU 上运行。 训练完成后,我将模型检查点保存如下:

torch.save('model_state_dict': model.state_dict(),'lstmmodelgpu.tar')

并且,为了推断,我将模型加载为:

# model definition
vocab_size = len(vocab_to_int)+1 
output_size = 1
embedding_dim = 300
hidden_dim = 256
n_layers = 2

model = SentimentLSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)

# loading model
device = torch.device('cpu')
checkpoint = torch.load('lstmmodelgpu.tar', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

但是,它引发了以下错误:

model.load_state_dict(checkpoint['model_state_dict'])
  File "workspace/envs/envdeeplearning/lib/python3.5/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SentimentLSTM:
    Missing key(s) in state_dict: "embedding.weight". 
    Unexpected key(s) in state_dict: "encoder.weight".

保存检查点时有什么遗漏的吗?

【问题讨论】:

你在用DataParallel吗? @harshit_k,不,我没有使用 DataParallel。我已按照教程进行操作:pytorch.org/tutorials/beginner/… 不知道为什么会报错,但您仍然可以尝试discuss.pytorch.org/t/… 中给出的解决方案或使用model.module.load_state_dict(checkpoint['model_state_dict'])。但是,我不确定这是否会奏效。 【参考方案1】:

这里有两点需要考虑。

    您提到您正在 GPU 上训练模型并使用它在 CPU 上进行推理,因此您需要在 load 函数中添加参数 ma​​p_location传递 torch.device('cpu')

    state_dict 键不匹配(在您的输出消息中指示),这可能是由于某些缺少键或您正在加载的 state_dict 中的键多于您所加载的模型目前使用。为此,您必须在 load_state_dict 函数中添加一个值为 False 的参数 strict。这将使方法忽略键的不匹配。

旁注:尝试对检查点文件使用 pt 或 pth 的扩展名,因为这是惯例。

【讨论】:

以上是关于保存和加载 Pytorch 模型检查点以进行推理不起作用的主要内容,如果未能解决你的问题,请参考以下文章

在 PyTorch 中加载迁移学习模型进行推理的正确方法是啥?

Pytorch自定义加载预训练权重

PyTorch学习笔记 2. 运行官网训练推理的入门示例

PyTorch学习笔记 2. 运行官网训练推理的入门示例

PyTorch中通过torch.save保存模型和torch.load加载模型介绍

推理实践丨如何使用MindStudio进行Pytorch模型离线推理