记录 PyTorch Lightning 的一个坑

Posted Xavier Jiezou

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了记录 PyTorch Lightning 的一个坑相关的知识,希望对你有一定的参考价值。

项目场景

PyTorch Lightning 对 PyTorch 做了进一步的封装,并继承了日志记录,分布式训练等工具,让我们能够把研究核心放在模型改进上而不是工程代码的编写。近期使用发现一个小问题,在此记录一下。

问题描述

模型训练的时候很正常,但验证的时候报错:

TypeError: validation_step() takes 3 positional arguments but 4 were given

并且,测试的时候也会遇到类似的问题。

原因分析

原来是我重写 LightningModulevalidation_steptest_step 方法时没有指定 batch_idx 参数,虽然这个参数在方法中没有被使用,但是却会被隐式地调用。batch_idx 就是批数据的索引,例如打印训练进度条的时候肯定会被调用的。但如果不显式地指定,就是导致位置参数和关键字参数识别冲突,从而引发异常。

解决方案

这是我原来的代码:

def validation_step(self, batch):
	pass

def test_step(self, batch):
    pass

加上 batch_idx 参数就行了:

def validation_step(self, batch, batch_idx):
	pass

def test_step(self, batch, batch_idx):
    pass

引用参考

https://github.com/PyTorchLightning/pytorch-lightning/issues/1034

以上是关于记录 PyTorch Lightning 的一个坑的主要内容,如果未能解决你的问题,请参考以下文章

如何在 pytorch-lightning 中使用 TensorBoard 记录器转储混淆矩阵?

跨多个模型的 Pytorch Lightning Tensorboard 记录器

使用 Pytorch Lightning 时关闭 Hydra 的控制台日志记录

使用 Pytorch Lightning DDP 时记录事物的正确方法

如何在 PyTorch Lightning 中获得所有时期的逐步验证损失曲线

PyTorch-lightning 模型在第一个 epoch 后内存不足