Getting error with PyTorch Lightning when passing model checkpoint
Posted: 2021-11-08 20:40:41

Question: I'm training a model for a multi-label classification problem using a Hugging Face model, and I'm using PyTorch Lightning to run the training.
Here is the code:
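import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Trigger early stopping when the validation loss stops improving
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)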
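For reference, patience=2 means training stops once val_loss has failed to improve for two consecutive validation checks. The function below is a simplified, framework-free model of that rule, assuming mode="min" and an optional min_delta threshold (default 0). early_stopping_epoch is a hypothetical helper, not part of Lightning, and the real callback handles further details (for example non-finite values) that are left out here.

```python
import math


def early_stopping_epoch(val_losses, patience=2, min_delta=0.0):
    """Return the 0-based check at which training would stop, or None.

    Hypothetical simplification of early stopping with mode="min":
    a value counts as an improvement only if it beats the best value
    seen so far by more than min_delta.
    """
    best = math.inf
    wait_count = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement: remember it and reset the patience counter.
            best = loss
            wait_count = 0
        else:
            # No improvement: use up one unit of patience.
            wait_count += 1
            if wait_count >= patience:
                return epoch
    return None
```

For example, early_stopping_epoch([0.9, 0.7, 0.72, 0.71, 0.6]) returns 3: the losses at checks 2 and 3 do not beat 0.7, so training stops before the improvement at check 4 is ever seen.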
Now we can start the training process:
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min"
)
trainer = pl.Trainer(
    logger=logger,
    callbacks=[early_stopping_callback],
    max_epochs=N_EPOCHS,
    checkpoint_callback=checkpoint_callback,
    gpus=1,
    progress_bar_refresh_rate=30
)
# checkpoint_callback=checkpoint_callback,
When I run it, I get this error:
~/.local/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py in _configure_checkpoint_callbacks(self, checkpoint_callback)
75 if isinstance(checkpoint_callback, Callback):
76 error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
---> 77 raise MisconfigurationException(error_msg)
78 if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
79 raise MisconfigurationException(
MisconfigurationException: Invalid type provided for checkpoint_callback: Expected bool but received <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>. Pass callback instances to the `callbacks` argument in the Trainer constructor instead.
How can I fix this?
Answer 1: You can look up the description of the checkpoint_callback argument on the documentation page for pl.Trainer:
checkpoint_callback (bool) – If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.
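The rule in that description, together with the check visible in the traceback, can be pictured with the standalone sketch below. It is a hypothetical simplification, not Lightning's actual code: Callback, ModelCheckpoint and MisconfigurationException are local stand-ins, configure_checkpoint_callbacks only mirrors the name of the method in the traceback, and the message for the checkpoint_callback=False conflict is illustrative wording.

```python
class Callback:
    """Hypothetical stand-in for Lightning's Callback base class."""


class ModelCheckpoint(Callback):
    """Hypothetical stand-in for Lightning's ModelCheckpoint callback."""


class MisconfigurationException(Exception):
    """Hypothetical stand-in for Lightning's MisconfigurationException."""


def configure_checkpoint_callbacks(checkpoint_callback, callbacks):
    """Simplified model of how a 1.x Trainer treats checkpoint_callback.

    checkpoint_callback is only an on/off switch; ModelCheckpoint
    instances belong in the callbacks list.
    """
    if not isinstance(checkpoint_callback, bool):
        error_msg = (
            "Invalid type provided for checkpoint_callback: "
            f"Expected bool but received {type(checkpoint_callback)}."
        )
        if isinstance(checkpoint_callback, Callback):
            error_msg += (
                " Pass callback instances to the `callbacks` argument"
                " in the Trainer constructor instead."
            )
        raise MisconfigurationException(error_msg)

    has_user_checkpoint = any(isinstance(cb, ModelCheckpoint) for cb in callbacks)
    if has_user_checkpoint and not checkpoint_callback:
        # Illustrative message: disabling checkpointing conflicts with a user ModelCheckpoint.
        raise MisconfigurationException(
            "checkpoint_callback=False conflicts with a ModelCheckpoint in callbacks"
        )
    if checkpoint_callback and not has_user_checkpoint:
        # Only add a default ModelCheckpoint when the user supplied none.
        return [*callbacks, ModelCheckpoint()]
    return list(callbacks)
```

In this model, passing a ModelCheckpoint instance as checkpoint_callback raises the same kind of error as in the question, while checkpoint_callback=True with a ModelCheckpoint already in callbacks returns the list unchanged.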
You should not pass a custom ModelCheckpoint to this argument. What you most likely want is to pass both EarlyStopping and ModelCheckpoint in the callbacks list:
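import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min")
# logger and N_EPOCHS are defined elsewhere in the training script.
# Both callback instances go in `callbacks`; checkpoint_callback keeps its default (True).
# gpus and progress_bar_refresh_rate are Lightning 1.x arguments; later releases removed them.
trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate=30)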
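With save_top_k=1, monitor="val_loss" and mode="min", the ModelCheckpoint above keeps only the checkpoint with the lowest validation loss seen so far. The sketch below models just that bookkeeping in plain Python, assuming mode="min" and save_top_k >= 1. simulate_top_k is a hypothetical helper: it tracks which epochs would be retained and does not write files or reproduce Lightning's filename handling.

```python
def simulate_top_k(val_losses, save_top_k=1):
    """Return the (epoch, val_loss) pairs that would be kept, best first.

    Hypothetical model of top-k checkpoint retention with mode="min".
    """
    if save_top_k < 1:
        raise ValueError("this sketch only models save_top_k >= 1")
    kept = {}  # epoch -> val_loss of the checkpoints currently retained
    for epoch, loss in enumerate(val_losses):
        if len(kept) < save_top_k:
            # Still room: keep this checkpoint unconditionally.
            kept[epoch] = loss
            continue
        worst_epoch = max(kept, key=kept.get)
        if loss < kept[worst_epoch]:
            # Better than the worst retained checkpoint: replace it.
            del kept[worst_epoch]
            kept[epoch] = loss
    return sorted(kept.items(), key=lambda item: item[1])
```

For the loss sequence [0.9, 0.7, 0.72, 0.71, 0.6], simulate_top_k returns [(4, 0.6)] with save_top_k=1, i.e. only the final, best epoch would survive.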