保存和加载 Pytorch 模型检查点以进行推理不起作用
Posted
技术标签:
【中文标题】保存和加载 Pytorch 模型检查点以进行推理不起作用【英文标题】:Saving and Loading Pytorch Model Checkpoint for inference not working 【发布时间】:2019-06-13 04:27:37 【问题描述】:我有一个使用 LSTM 训练的模型。该模型是在 GPU 上训练的(在 Google COLABORATORY 上)。 我必须保存模型以进行推理;我将在 CPU 上运行。 训练完成后,我将模型检查点保存如下:
torch.save('model_state_dict': model.state_dict(),'lstmmodelgpu.tar')
并且,为了推断,我将模型加载为:
# model definition
vocab_size = len(vocab_to_int)+1
output_size = 1
embedding_dim = 300
hidden_dim = 256
n_layers = 2
model = SentimentLSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)
# loading model
device = torch.device('cpu')
checkpoint = torch.load('lstmmodelgpu.tar', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
但是,它引发了以下错误:
model.load_state_dict(checkpoint['model_state_dict'])
File "workspace/envs/envdeeplearning/lib/python3.5/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SentimentLSTM:
Missing key(s) in state_dict: "embedding.weight".
Unexpected key(s) in state_dict: "encoder.weight".
保存检查点时有什么遗漏的吗?
【问题讨论】:
你在用DataParallel
吗?
@harshit_k,不,我没有使用 DataParallel。我已按照教程进行操作:pytorch.org/tutorials/beginner/…
不知道为什么会报错,但您仍然可以尝试discuss.pytorch.org/t/… 中给出的解决方案或使用model.module.load_state_dict(checkpoint['model_state_dict'])
。但是,我不确定这是否会奏效。
【参考方案1】:
这里有两点需要考虑。
您提到您正在 GPU 上训练模型并使用它在 CPU 上进行推理,因此您需要在 load 函数中添加参数 map_location传递 torch.device('cpu')。
state_dict 键不匹配(在您的输出消息中指示),这可能是由于某些缺少键或您正在加载的 state_dict 中的键多于您所加载的模型目前使用。为此,您必须在 load_state_dict 函数中添加一个值为 False 的参数 strict。这将使方法忽略键的不匹配。
旁注:尝试对检查点文件使用 pt 或 pth 的扩展名,因为这是惯例。
【讨论】:
以上是关于保存和加载 Pytorch 模型检查点以进行推理不起作用的主要内容,如果未能解决你的问题,请参考以下文章
在 PyTorch 中加载迁移学习模型进行推理的正确方法是啥?