未雨绸缪:随手保存 PyTorch 训练模型
Posted 集智学园
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了未雨绸缪:随手保存 PyTorch 训练模型相关的知识,希望对你有一定的参考价值。
我们都知道,训练一个深度神经网络是需要挺长的时间的,即使是在高性能服务器上,有些训练也要持续几天之久。
不知道大家有没有遇到这种尴尬的情况:花了一天时间好不容易训练模型到 60% 啦,突然,机房要停电?学长要占用服务器?购买的 GPU 计算时间用完了等等。
怎么办??
训练了一大半的模型不能功亏一篑呀,能不能把没训练完成的模型先保存下来,回头有机会了再加载接着训练?
那么今天我就给大家来介绍一个小技巧。
教你如何将未训练完成的 PyTorch 模型保存下来,而且不只是模型,训练过程中的优化器(optimizer),迭代数(epochs),以及正确率(score)等等,都可以以文件的形式保存下来,并在未来继续加载训练。
好
下面重点来啦!
(敲黑板)
我们利用 torch.save方法,使用这个方法来保存模型。
首先让我们封装一个用来保存模型的函数。
它有三个参数:
第一个参数就是我们要保存的模型状态以及优化器的状态,迭代次数等信息。
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
然后我们用下面的方式来调用它:
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
}, is_best)
这样我们就将模型保存为一个文件。
那要怎么重新加载呢?
像这样:
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
# 通过参数指定要加载的模型文件名
checkpoint = torch.load(args.resume)
# 读取出保存的模型训练参数
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
# 重新加载模型训练进度
model.load_state_dict(checkpoint['state_dict'])
# 重新加载优化器进度
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
这样就可以保存(save)&加载(load)训练中的 PyTorch 模型啦!善用SL(save&load)大法,不但不怕服务器突然掉链子,还能够把训练各个阶段的模型都保存下来,用于研究模型训练的各个步骤。
本节中的全部代码取自 PyTorch 的 ImageNet 官方代码,你可以在这里找到完整代码。
https://github.com/pytorch/examples/blob/master/imagenet/main.py#L139
本文编译自 PyTorch 官方论坛,原址:
https://discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/2610/7
推荐阅读:
为什么他们要来集智AI学园学习 PyTorch?
为什么机器学习研究者都投入了 PyTorch 的怀抱?
重磅系列课:火炬上的深度学习(下)
获取更多更有趣的AI教程吧!
学园网站:campus.swarma.org
商务合作|zhangqian@swarma.org
投稿转载|wangjiannan@swarma.org
以上是关于未雨绸缪:随手保存 PyTorch 训练模型的主要内容,如果未能解决你的问题,请参考以下文章