Pytorch 保存模型用户警告:无法检索网络类型容器的源代码

Posted

技术标签:

【中文标题】Pytorch 保存模型用户警告:无法检索网络类型容器的源代码【英文标题】:Pytorch saving model UserWarning: Couldn't retrieve source code for container of type Network 【发布时间】:2019-02-16 00:38:51 【问题描述】:

使用 Pytorch 保存模型时:

torch.save(model, 'checkpoint.pth')

我收到以下警告:

/opt/conda/lib/python3.6/site-packages/torch/serialization.py:193: 用户警告:无法检索类型容器的源代码 网络。加载时不会检查其正确性。 “类型” + obj.name + ". 不会被勾选 "

当我加载它时,我收到以下错误:

state_dict = torch.load('checkpoint_state_dict.pth')
model = torch.load('checkpoint.pth')
model.load_state_dict(state_dict)


AttributeError                            Traceback (most recent call last)
<ipython-input-2-6a79854aef0f> in <module>()
      2 state_dict = torch.load('checkpoint_state_dict.pth')
      3 model = 0
----> 4 model = torch.load('checkpoint.pth')
      5 model.load_state_dict(state_dict)

/opt/conda/lib/python3.6/site-packages/torch/serialization.py in load(f, map_location, pickle_module)
    301         f = open(f, 'rb')
    302     try:
--> 303         return _load(f, map_location, pickle_module)
    304     finally:
    305         if new_fd:

/opt/conda/lib/python3.6/site-packages/torch/serialization.py in _load(f, map_location, pickle_module)
    467     unpickler = pickle_module.Unpickler(f)
    468     unpickler.persistent_load = persistent_load
--> 469     result = unpickler.load()
    470 
    471     deserialized_storage_keys = pickle_module.load(f)

AttributeError: Can't get attribute 'Network' on <module '__main__'>

为什么无法保存模型并完全重新加载?

【问题讨论】:

似乎 pytorch 找不到您对 NN 模型的定义。 【参考方案1】:

保存

torch.save('state_dict': model.state_dict(), 'checkpoint.pth.tar')

加载中

model = describe_model()
checkpoint = torch.load('checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])

【讨论】:

谢谢,但是我在哪里可以找到 describe_model() ? @DarioFederici 这是你的模特。您在训练之前在代码中定义了这一点。您需要再次定义模型以将检查点加载到您的模型。 好的,所以 describe_model() 是 NN 定义。 @DarioFederici 。是的,您的神经网络,首先定义模型,然后将权重上传到其中。 执行上述代码 "torch.save('state_dict': model.state_dict(),...)" 返回 TypeError: 'collections.OrderedDict' 对象不可调用。

以上是关于Pytorch 保存模型用户警告:无法检索网络类型容器的源代码的主要内容,如果未能解决你的问题,请参考以下文章

[Pytorch]Pytorch 保存模型与加载模型(转)

在无法访问模型类代码的情况下保存 PyTorch 模型

PyTorch 模型保存错误:“无法腌制本地对象”

将pytorch模型转为coreml后,预测结果差很多

PyTorch:模型save和load

PyTorch:模型save和load