PyTorch Lightning 中的批量测试及其存在的问题

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch Lightning 中的批量测试及其存在的问题相关的知识,希望对你有一定的参考价值。

参考技术A

2022-1-5, Wed., 13:37 于鸢尾花基地
可以采用如下方式对之前保存的预训练模型进行批量测试:

然而,在上述循环中,通过 trainer.test 每执行一次测试,都只是执行了一个 epoch 的测试(也就是执行多次 ptl_module.test_step 和一次 ptl_module.test_epoch_end ),而不可能把 ckpt_list 中的多个预训练模型( checkpoint )当做多个 epoch ,多次执行 ptl_module.test_epoch_end 。

我们期望,对多个 checkpoint 的测试能像对多个 epoch 的训练一样简洁:

怎么做到?在训练过程中,要训练多少个 epoch 是由参数 max_epochs 来决定的;而在测试过程中,怎么办?PTL并非完整地保存了所有epoch的预训练模型。

由于在测试过程中对各 checkpoint 是独立测试的,如果要统计多个 checkpoint 的最优性能(如最大PSNR/SSIM),怎么办?这里的一个关键问题是如何保存每次测试得到的评估结果,好像PTL并未对此提供接口。

解决方案
PTL提供了“回调类(Callback)”(在 pytorch_lightning.callbacks 中),可以自定义一个回调类,并重载 on_test_epoch_end 方法,来监听 ptl_module.test_epoch_end 。
如何使用?只需要在定义 trainer 时,把该自定义的回调函数加入其参数 callbacks 即可: ptl.Trainer(callbacks=[MetricTracker()]) 。这里, MetricTracker 为自定义的回调类,具体如下:

评论: 由于 MetricTracker 具有与 Trainer 相同的生命周期,因此,在整个测试过程中, MetricTracker 能够维护一个最优的评估结果 optim_metrics 。

以上是关于PyTorch Lightning 中的批量测试及其存在的问题的主要内容,如果未能解决你的问题,请参考以下文章

pytorch lightning使用(简要介绍)

pytorch-lightning 中的正态分布采样

pytorch lightning 手写数字分类实例

无法从 Pytorch-Lightning 中的检查点加载模型

加载器的无效数据类型 - Pytorch Lightning DataModule

如何禁用 PyTorch-Lightning 记录器的日志记录?