pytorch实现断点续训

Posted code_kd

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch实现断点续训相关的知识,希望对你有一定的参考价值。

可以在训练完每个epoch之后,保存下epoch,optimizer,net的信息。

checkpoint = 
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(checkpoint, CHECKPOINT_FILE)

如果需要从上次状态接着训练的话:

if resume:
        # 恢复上次的训练状态
        print("Resume from checkpoint...")
        checkpoint = torch.load(CHECKPOINT_FILE)
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        initepoch = checkpoint['epoch']+1
        #从上次记录的损失和正确率接着记录
        dict = torch.load(ACC_LOSS_FILE)
        loss_record = dict['loss']
        acc_record = dict['acc']

如果使用了scheduler的话,需要更改其中的last_epoch,保证学习率也随之更新。

以上是关于pytorch实现断点续训的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch分布式训练与断点续训

PyTorch保存模型断点以及加载断点继续训练

第四讲 网络八股拓展--用mnist数据集实现断点续训, 绘制准确图像和损失图像

第四讲 网络八股拓展--用mnist数据集实现断点续训, 绘制准确图像和损失图像

基于pytorch实现简单的分类模型训练

深度学习理论与实战PyTorch实现