Saving and reloading a model in PyTorch
I have a VGG16 model with the following structure:
<bound method Module.state_dict of VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Network(
(hidden_layers): ModuleList(
(0): Linear(in_features=25088, out_features=4096, bias=True)
)
(output): Linear(in_features=4096, out_features=102, bias=True)
(dropout): Dropout(p=0.5)
)
When I save the model with the following code:
checkpoint = {'input_size': 25088,
'output_size': 102,
'hidden_layers': [each.out_features for each in model.hidden_layers],
'state_dict': model.state_dict()}
torch.save(checkpoint, 'checkpoint.pth')
I get the following error:
AttributeError Traceback (most recent call last)
<ipython-input-13-b4654570e6e8> in <module>()
2 checkpoint = {'input_size': 25088,
3 'output_size': 102,
----> 4 'hidden_layers': [each.out_features for each in model.hidden_layers],
5 'state_dict': model.state_dict()}
6
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
530 return modules[name]
531 raise AttributeError("'{}' object has no attribute '{}'".format(
--> 532 type(self).__name__, name))
533
534 def __setattr__(self, name, value):
AttributeError: 'VGG' object has no attribute 'hidden_layers'
However, the VGG model does have hidden_layers. How can I save a VGG model that was trained with transfer learning?
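Answer
Save only the model's state_dict, then load it back into a model that was built with the same architecture: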
torch.save(model.state_dict(), 'checkpoint.pth')
state_dict = torch.load('checkpoint.pth')
model.load_state_dict(state_dict)
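The printed structure also suggests why the original code fails: hidden_layers sits under (classifier), so it is reached as model.classifier.hidden_layers rather than model.hidden_layers. If the extra checkpoint fields are still wanted, something like the sketch below may work. Note that Network and TinyVGG here are hypothetical stand-ins (the question does not show the real Network class, and the layer sizes are shrunk so the example runs quickly); only the attribute path and the save/load calls are the point.

```python
import os
import tempfile

import torch
from torch import nn


class Network(nn.Module):
    """Hypothetical reconstruction of the question's custom classifier head."""

    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):
        super().__init__()
        sizes = [input_size] + list(hidden_layers)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )
        self.output = nn.Linear(sizes[-1], output_size)
        self.dropout = nn.Dropout(p=drop_p)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = self.dropout(torch.relu(layer(x)))
        return self.output(x)


class TinyVGG(nn.Module):
    """Small stand-in for VGG: a `features` block plus a `classifier` head."""

    def __init__(self, classifier):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU()
        )
        self.classifier = classifier

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))


# 4 channels * 4 * 4 pixels = 64 flattened features for a 3x4x4 input.
model = TinyVGG(Network(input_size=64, output_size=5, hidden_layers=[16]))

# hidden_layers belongs to the classifier submodule, not to the top-level model.
checkpoint = {
    "input_size": 64,
    "output_size": 5,
    "hidden_layers": [
        layer.out_features for layer in model.classifier.hidden_layers
    ],
    "state_dict": model.state_dict(),
}

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save(checkpoint, path)

# Rebuild the same architecture from the saved sizes, then restore the weights.
loaded = torch.load(path)
restored = TinyVGG(
    Network(loaded["input_size"], loaded["output_size"], loaded["hidden_layers"])
)
restored.load_state_dict(loaded["state_dict"])
restored.eval()
```

Storing the layer sizes next to the state_dict lets the loading code rebuild the classifier without hard-coding its shape, which is presumably what the original checkpoint dictionary was meant to do.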