通过模型检查点时 Pytorch 闪电出错

Posted

技术标签:

【中文标题】通过模型检查点时 Pytorch 闪电出错【英文标题】:Getting error with Pytorch lightning when passing model checkpoint 【发布时间】:2021-11-08 20:40:41 【问题描述】:

我正在使用 Hugging 人脸模型训练一个多标签分类问题。我正在使用 Pytorch Lightning 来训练模型。

代码如下:

当损失没有改善时,提前停止触发

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

我们可以开始训练过程了:

checkpoint_callback = ModelCheckpoint(
  dirpath="checkpoints",
  filename="best-checkpoint",
  save_top_k=1,
  verbose=True,
  monitor="val_loss",
  mode="min"
)


trainer = pl.Trainer(
  logger=logger,
  callbacks=[early_stopping_callback],
  max_epochs=N_EPOCHS,
 checkpoint_callback=checkpoint_callback,
  gpus=1,
  progress_bar_refresh_rate=30
)
# checkpoint_callback=checkpoint_callback,

一旦我运行它,我就会得到这个错误:

~/.local/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py in _configure_checkpoint_callbacks(self, checkpoint_callback)
     75             if isinstance(checkpoint_callback, Callback):
     76                 error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
---> 77             raise MisconfigurationException(error_msg)
     78         if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
     79             raise MisconfigurationException(

MisconfigurationException: Invalid type provided for checkpoint_callback: Expected bool but received <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>. Pass callback instances to the `callbacks` argument in the Trainer constructor instead.

我该如何解决这个问题?

【问题讨论】:

【参考方案1】:

您可以在pl.Trainer的文档页面中查找checkpoint_callback参数的描述:

checkpoint_callback (bool) – 如果True,启用检查点。如果回调中没有用户定义的ModelCheckpoint,它将配置默认的ModelCheckpoint回调。

您不应将自定义 ModelCheckpoint 传递给此参数。我相信您要做的是在callbacks list 中同时传递EarlyStoppingModelCheckpoint

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min")

trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate=30)

【讨论】:

以上是关于通过模型检查点时 Pytorch 闪电出错的主要内容,如果未能解决你的问题,请参考以下文章

防止在训练模型时信息丢失 用于TensorFlowKeras和PyTorch的检查点教程

pytorch闪电模型的输出预测

TF2.0:翻译模型:恢复保存的模型时出错:检查点(根)中未解析的对象.optimizer.iter:属性

用 pytorch 闪电组织张量板图

无法从 Pytorch-Lightning 中的检查点加载模型

保存和加载 Pytorch 模型检查点以进行推理不起作用