pytorch实现断点续训
Posted code_kd
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch实现断点续训相关的知识,希望对你有一定的参考价值。
可以在训练完每个epoch之后,保存下epoch,optimizer,net的信息。
checkpoint =
'epoch': epoch,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(checkpoint, CHECKPOINT_FILE)
如果需要从上次状态接着训练的话:
if resume:
# 恢复上次的训练状态
print("Resume from checkpoint...")
checkpoint = torch.load(CHECKPOINT_FILE)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
initepoch = checkpoint['epoch']+1
#从上次记录的损失和正确率接着记录
dict = torch.load(ACC_LOSS_FILE)
loss_record = dict['loss']
acc_record = dict['acc']
如果使用了scheduler的话,需要更改其中的last_epoch,保证学习率也随之更新。
以上是关于pytorch实现断点续训的主要内容,如果未能解决你的问题,请参考以下文章
第四讲 网络八股拓展--用mnist数据集实现断点续训, 绘制准确图像和损失图像