Pytorch保存和重装模型

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch保存和重装模型相关的知识,希望对你有一定的参考价值。

我对VGG16型号有以下结构:

<bound method Module.state_dict of VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Network(
    (hidden_layers): ModuleList(
      (0): Linear(in_features=25088, out_features=4096, bias=True)
    )
    (output): Linear(in_features=4096, out_features=102, bias=True)
    (dropout): Dropout(p=0.5)
  )

使用以下代码保存模型时:

checkpoint = {'input_size': 25088,
              'output_size': 102,
              'hidden_layers': [each.out_features for each in model.hidden_layers],
              'state_dict': model.state_dict()}

torch.save(checkpoint, 'checkpoint.pth')

我收到以下错误:

AttributeError                            Traceback (most recent call last)
<ipython-input-13-b4654570e6e8> in <module>()
      2 checkpoint = {'input_size': 25088,
      3               'output_size': 102,
----> 4               'hidden_layers': [each.out_features for each in model.hidden_layers],
      5               'state_dict': model.state_dict()}
      6 

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
    530                 return modules[name]
    531         raise AttributeError("'{}' object has no attribute '{}'".format(
--> 532             type(self).__name__, name))
    533 
    534     def __setattr__(self, name, value):

AttributeError:'VGG'对象没有属性'hidden_​​layers'

但是,VGG有hidden_​​layers。如何通过学习转移来保存VGG?

答案
torch.save(model.state_dict(), 'checkpoint.pth')

state_dict = torch.load('checkpoint.pth')
model.load_state_dict(state_dict)

以上是关于Pytorch保存和重装模型的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch 保存模型用户警告:无法检索网络类型容器的源代码

pytorch 保存和加载模型

PyTorch 模型保存错误:“无法腌制本地对象”

Pytorch文本分类(imdb数据集),含DataLoader数据加载,最优模型保存

Pytorch 之 模型的保存与调用

win7旗舰版系统下IE11无法卸载和重装怎么办