PyTorch Lightning 中的批量测试及其存在的问题
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch Lightning 中的批量测试及其存在的问题相关的知识,希望对你有一定的参考价值。
参考技术A 2022-1-5, Wed., 13:37 于鸢尾花基地
可以采用如下方式对之前保存的预训练模型进行批量测试:
然而,在上述循环中,通过 trainer.test 每执行一次测试,都只是执行了一个 epoch 的测试(也就是执行多次 ptl_module.test_step 和一次 ptl_module.test_epoch_end ),而不可能把 ckpt_list 中的多个预训练模型( checkpoint )当做多个 epoch ,多次执行 ptl_module.test_epoch_end 。
我们期望,对多个 checkpoint 的测试能像对多个 epoch 的训练一样简洁:
怎么做到?在训练过程中,要训练多少个 epoch 是由参数 max_epochs 来决定的;而在测试过程中,怎么办?PTL并非完整地保存了所有epoch的预训练模型。
由于在测试过程中对各 checkpoint 是独立测试的,如果要统计多个 checkpoint 的最优性能(如最大PSNR/SSIM),怎么办?这里的一个关键问题是如何保存每次测试得到的评估结果,好像PTL并未对此提供接口。
解决方案
PTL提供了“回调类(Callback)”(在 pytorch_lightning.callbacks 中),可以自定义一个回调类,并重载 on_test_epoch_end 方法,来监听 ptl_module.test_epoch_end 。
如何使用?只需要在定义 trainer 时,把该自定义的回调函数加入其参数 callbacks 即可: ptl.Trainer(callbacks=[MetricTracker()]) 。这里, MetricTracker 为自定义的回调类,具体如下:
评论: 由于 MetricTracker 具有与 Trainer 相同的生命周期,因此,在整个测试过程中, MetricTracker 能够维护一个最优的评估结果 optim_metrics 。
以上是关于PyTorch Lightning 中的批量测试及其存在的问题的主要内容,如果未能解决你的问题,请参考以下文章
无法从 Pytorch-Lightning 中的检查点加载模型