Pytorch 保存模型用户警告:无法检索网络类型容器的源代码
Posted
技术标签:
【中文标题】Pytorch 保存模型用户警告:无法检索网络类型容器的源代码【英文标题】:Pytorch saving model UserWarning: Couldn't retrieve source code for container of type Network 【发布时间】:2019-02-16 00:38:51 【问题描述】:使用 Pytorch 保存模型时:
torch.save(model, 'checkpoint.pth')
我收到以下警告:
/opt/conda/lib/python3.6/site-packages/torch/serialization.py:193: 用户警告:无法检索类型容器的源代码 网络。加载时不会检查其正确性。 “类型” + obj.name + ". 不会被勾选 "
当我加载它时,我收到以下错误:
state_dict = torch.load('checkpoint_state_dict.pth')
model = torch.load('checkpoint.pth')
model.load_state_dict(state_dict)
AttributeError Traceback (most recent call last)
<ipython-input-2-6a79854aef0f> in <module>()
2 state_dict = torch.load('checkpoint_state_dict.pth')
3 model = 0
----> 4 model = torch.load('checkpoint.pth')
5 model.load_state_dict(state_dict)
/opt/conda/lib/python3.6/site-packages/torch/serialization.py in load(f, map_location, pickle_module)
301 f = open(f, 'rb')
302 try:
--> 303 return _load(f, map_location, pickle_module)
304 finally:
305 if new_fd:
/opt/conda/lib/python3.6/site-packages/torch/serialization.py in _load(f, map_location, pickle_module)
467 unpickler = pickle_module.Unpickler(f)
468 unpickler.persistent_load = persistent_load
--> 469 result = unpickler.load()
470
471 deserialized_storage_keys = pickle_module.load(f)
AttributeError: Can't get attribute 'Network' on <module '__main__'>
为什么无法保存模型并完全重新加载?
【问题讨论】:
似乎 pytorch 找不到您对 NN 模型的定义。 【参考方案1】:保存
torch.save('state_dict': model.state_dict(), 'checkpoint.pth.tar')
加载中
model = describe_model()
checkpoint = torch.load('checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
【讨论】:
谢谢,但是我在哪里可以找到 describe_model() ? @DarioFederici 这是你的模特。您在训练之前在代码中定义了这一点。您需要再次定义模型以将检查点加载到您的模型。 好的,所以 describe_model() 是 NN 定义。 @DarioFederici 。是的,您的神经网络,首先定义模型,然后将权重上传到其中。 执行上述代码 "torch.save('state_dict': model.state_dict(),...)" 返回 TypeError: 'collections.OrderedDict' 对象不可调用。以上是关于Pytorch 保存模型用户警告:无法检索网络类型容器的源代码的主要内容,如果未能解决你的问题,请参考以下文章