记录 PyTorch Lightning 的一个坑
Posted Xavier Jiezou
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了记录 PyTorch Lightning 的一个坑相关的知识,希望对你有一定的参考价值。
项目场景
PyTorch Lightning 对 PyTorch 做了进一步的封装,并继承了日志记录,分布式训练等工具,让我们能够把研究核心放在模型改进上而不是工程代码的编写。近期使用发现一个小问题,在此记录一下。
问题描述
模型训练的时候很正常,但验证的时候报错:
TypeError: validation_step() takes 3 positional arguments but 4 were given
并且,测试的时候也会遇到类似的问题。
原因分析
原来是我重写 LightningModule 的 validation_step
和 test_step
方法时没有指定 batch_idx
参数,虽然这个参数在方法中没有被使用,但是却会被隐式地调用。batch_idx
就是批数据的索引,例如打印训练进度条的时候肯定会被调用的。但如果不显式地指定,就是导致位置参数和关键字参数识别冲突,从而引发异常。
解决方案
这是我原来的代码:
def validation_step(self, batch):
pass
def test_step(self, batch):
pass
加上 batch_idx 参数就行了:
def validation_step(self, batch, batch_idx):
pass
def test_step(self, batch, batch_idx):
pass
引用参考
https://github.com/PyTorchLightning/pytorch-lightning/issues/1034
以上是关于记录 PyTorch Lightning 的一个坑的主要内容,如果未能解决你的问题,请参考以下文章
如何在 pytorch-lightning 中使用 TensorBoard 记录器转储混淆矩阵?
跨多个模型的 Pytorch Lightning Tensorboard 记录器
使用 Pytorch Lightning 时关闭 Hydra 的控制台日志记录
使用 Pytorch Lightning DDP 时记录事物的正确方法