中断后如何恢复训练 pl.Trainer?

Posted

技术标签:

【中文标题】中断后如何恢复训练 pl.Trainer?【英文标题】:How can I resume training pl.Trainer after interruption? 【发布时间】:2021-05-31 10:11:03 【问题描述】:

我有 Model 和 Trainer pytorch-lightning 对象,它们的初始化如下:

checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join('experiments', experiment_name, 'epoch-avg_valid_iou:.4f'),
    save_top_k=1,
    verbose=True,
    monitor='avg_valid_iou',
    mode='max',
    prefix=''
)
model = nn.DataParallel (FaultNetPL(batch_size=5)).cuda()
trainer = Trainer( checkpoint_callback=checkpoint_callback, 
                  max_epochs=500,gpus=1,
                  logger=logger)

然后我开始使用:

trainer.fit(model)

但是训练被中断了,现在我想使用第 N 次迭代的检查点来恢复它 所以我尝试将模型和训练器初始化为:

model = FaultNetPL.load_from_checkpoint('experiments/VNET/epoch=77-avg_valid_iou=0.7604.ckpt',batch_size=5)
trainer = Trainer(resume_from_checkpoint = 'epoch=77-avg_valid_iou=0.7604.ckpt', 
                  checkpoint_callback=checkpoint_callback, 
                  max_epochs=500,gpus=1,
                  logger=logger)

但一次又一次地从头开始训练(从第 0 个 epoch 和巨大的错误开始)。我错过了什么?

【问题讨论】:

【参考方案1】:

您不仅应该保存模型状态,还应该保存优化器状态和起始 epoch 值。例如:

state(
       'epoch': epoch + 1,
       'state_dict': model.module.state_dict(),
       'optimizer': optimizer.state_dict(),
      )

保存检查点后,您可以通过以下命令恢复训练:

checkpoint = torch.load(state_file)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_val = checkpoint['epoch']

for epoch in range(start_val, max_val):
   ...
   ...

【讨论】:

感谢您的回复但实际上应该在循环中“.... ....”。我也在使用闪电,而不是 pytorch 我在代码中也提到了我初始化ModelCheckpoint的方式。

以上是关于中断后如何恢复训练 pl.Trainer?的主要内容,如果未能解决你的问题,请参考以下文章

从可用区中断中自动恢复?

如何使用ssd训练自己的数据

音频会话中断后恢复 twilio 通话

AVAudioPlayer 中断后从相同状态恢复音频剪辑

如何在 Huggingface Trainer 课程中恢复训练时避免迭代 Dataloader?

Spring Cloud Bus 网络中断后无法恢复