未雨绸缪:随手保存 PyTorch 训练模型

Posted 集智学园

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了未雨绸缪:随手保存 PyTorch 训练模型相关的知识,希望对你有一定的参考价值。

我们都知道,训练一个深度神经网络是需要挺长的时间的,即使是在高性能服务器上,有些训练也要持续几天之久。

不知道大家有没有遇到这种尴尬的情况:花了一天时间好不容易训练模型到 60% 啦,突然,机房要停电?学长要占用服务器?购买的 GPU 计算时间用完了等等。

怎么办??未雨绸缪:随手保存 PyTorch 训练模型

训练了一大半的模型不能功亏一篑呀,能不能把没训练完成的模型先保存下来,回头有机会了再加载接着训练?

那么今天我就给大家来介绍一个小技巧。

教你如何将未训练完成的 PyTorch 模型保存下来,而且不只是模型,训练过程中的优化器(optimizer),迭代数(epochs),以及正确率(score)等等,都可以以文件的形式保存下来,并在未来继续加载训练。

下面重点来啦!

(敲黑板)

我们利用 torch.save方法,使用这个方法来保存模型。

首先让我们封装一个用来保存模型的函数。

它有三个参数:

第一个参数就是我们要保存的模型状态以及优化器的状态,迭代次数等信息。

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):

    torch.save(state, filename)

    if is_best:

        shutil.copyfile(filename, 'model_best.pth.tar')

然后我们用下面的方式来调用它:

save_checkpoint({

    'epoch': epoch + 1,

    'arch': args.arch,

    'state_dict': model.state_dict(),

    'best_prec1': best_prec1,

    'optimizer' : optimizer.state_dict(),

}, is_best)

这样我们就将模型保存为一个文件。

那要怎么重新加载呢?

像这样:

if args.resume:

    if os.path.isfile(args.resume):

        print("=> loading checkpoint '{}'".format(args.resume))

        # 通过参数指定要加载的模型文件名

        checkpoint = torch.load(args.resume)

        # 读取出保存的模型训练参数

        args.start_epoch = checkpoint['epoch']

        best_prec1 = checkpoint['best_prec1']

        # 重新加载模型训练进度

        model.load_state_dict(checkpoint['state_dict'])

        # 重新加载优化器进度

        optimizer.load_state_dict(checkpoint['optimizer'])

        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    else:

        print("=> no checkpoint found at '{}'".format(args.resume))

这样就可以保存(save)&加载(load)训练中的 PyTorch 模型啦!善用SL(save&load)大法,不但不怕服务器突然掉链子,还能够把训练各个阶段的模型都保存下来,用于研究模型训练的各个步骤。

未雨绸缪:随手保存 PyTorch 训练模型

本节中的全部代码取自 PyTorch 的 ImageNet 官方代码,你可以在这里找到完整代码。

https://github.com/pytorch/examples/blob/master/imagenet/main.py#L139

本文编译自 PyTorch 官方论坛,原址:

https://discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/2610/7

未雨绸缪:随手保存 PyTorch 训练模型


推荐阅读:

为什么他们要来集智AI学园学习 PyTorch?

为什么机器学习研究者都投入了 PyTorch 的怀抱?

重磅系列课:火炬上的深度学习(下)



获取更多更有趣的AI教程吧!

学园网站:campus.swarma.org



 商务合作|zhangqian@swarma.org     

投稿转载|wangjiannan@swarma.org


点击学习PyTorch

以上是关于未雨绸缪:随手保存 PyTorch 训练模型的主要内容,如果未能解决你的问题,请参考以下文章

在 pytorch 中为聊天机器人加载经过训练的模型保存

PyTorch下载的预训练模型的保存位置(Windows)

在 PyTorch 中保存训练模型的最佳方法是啥? [关闭]

pytorch量化感知训练(QAT)示例---ResNet

Pytorch模型训练&保存/加载(搭建完整流程)

[深度学习] Pytorch—— 多/单GPUCPU,训练保存加载模型参数问题