PyTorch 模型保存错误:“无法腌制本地对象”

Posted

技术标签:

【中文标题】PyTorch 模型保存错误:“无法腌制本地对象”【英文标题】:PyTorch model saving error: "Can't pickle local object" 【发布时间】:2020-05-29 16:36:55 【问题描述】:

当我尝试使用这段代码保存 PyTorch 模型时:

checkpoint = 'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()
torch.save(checkpoint, 'Checkpoint.pth')

我收到以下错误:

    E:\PROGRAM FILES\Anaconda\envs\staj_projesi\lib\site-packages\torch\serialization.py:251: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.
...

      "type " + obj.__name__ + ". It won't be checked "
    Can't pickle local object 'trainModel.<locals>.Net'

当我尝试使用这段代码保存 PyTorch 模型时:

checkpoint = 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()
torch.save(checkpoint, 'Checkpoint.pth')

我没有收到任何错误,但我想保存 ANN 类。我怎么解决这个问题?另外,我可以在之前的其他项目中保存具有第一个结构的模型

【问题讨论】:

您是否尝试将dill 作为泡菜模块提供给torch.save 【参考方案1】:

你不能! torch.save 仅保存对象 state_dict()

当您使用以下内容时:

checkpoint = 'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()
torch.save(checkpoint, 'Checkpoint.pth')

您正在尝试保存模型本身,但此数据保存在 model.state_dict() 中,当使用 state_dict 加载模型时,您应该首先启动模型对象。

这正是第二种方法正常工作的原因:

checkpoint = 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()
torch.save(checkpoint, 'Checkpoint.pth')

我建议阅读以下链接中有关如何正确保存\加载模型的 pytorch 文档: https://pytorch.org/tutorials/beginner/saving_loading_models.html

【讨论】:

感谢您的回答。那么有什么方法可以存储类吗?我必须每次都定义类吗? @GöktuğYıldırım 您应该将网络类与主脚本分开,然后在需要的地方导入它【参考方案2】:

使用通常的正确方法来保存和加载模型https://pytorch.org/tutorials/beginner/saving_loading_models.html,如果您有要保存的 args 或 dicts,也许还有一个 lambda 函数,有时我使用 dill 并且错误消失了。例如

def save_for_meta_learning(args, ckpt_filename='ckpt.pt'):
    if is_lead_worker(args.rank):
        import dill
        args.logger.save_current_plots_and_stats()
        # - ckpt
        assert uutils.xor(args.training_mode == 'epochs', args.training_mode == 'iterations')
        args_pickable = uutils.make_args_pickable(args)
        # args.meta_learner.args = args_pickable
        f: nn.Module = get_model_from_ddp(args.base_model)
        # pickle vs torch_uu.save https://discuss.pytorch.org/t/advantages-disadvantages-of-using-pickle-module-to-save-models-vs-torch-save/79016
        torch.save('training_mode': args.training_mode,  # its or epochs
                    'it': args.it,
                    'epoch_num': args.epoch_num,
                    # 'args': args_pickable,
                    'args_pickable': args_pickable,
                    # 'meta_learner': args.meta_learner,
                    'meta_learner_str': str(args.meta_learner),
                    # 'f': f,
                    'f_state_dict': f.state_dict(),
                    'f_str': str(f),
                    # 'f_modules': f._modules,
                    # 'f_modules_str': str(f._modules),
                    'outer_opt_state_dict': args.outer_opt.state_dict()
                    ,
                   pickle_module=dill,
                   f=args.log_root / ckpt_filename)

【讨论】:

以上是关于PyTorch 模型保存错误:“无法腌制本地对象”的主要内容,如果未能解决你的问题,请参考以下文章

Python multiprocessing basic - 无法腌制本地对象并用尽输入

Python多处理:AttributeError:无法腌制本地对象

如果我希望 OpenCV dnn 模块可以加载 PyTorch 的模型,我应该如何保存它

Pytorch 之 模型的保存与调用

Pytorch模型保存与加载,并在加载的模型基础上继续训练

[Pytorch]Pytorch 保存模型与加载模型(转)