Python: TypeError: val_dataloader() missing 1 required positional argument: 'self'

Posted: 2021-09-15 04:45:10

Question:

I am using PyTorch Lightning for an image classification task, but I get a TypeError when running it. I created the data module and model as shown in the PyTorch Lightning examples. The model is VGG16 with batch normalization.

FruitsDataModule 中,我只得到val_dataloader 的错误,而不是train_dataloader 的错误,这令人困惑,因为这两个函数都在用不同的数据做完全相同的事情。

The relevant code is shown below.

Data module

class FruitsDataModule(pl.LightningDataModule):
    
    def __init__(self):
        super().__init__()
        self.transform = transforms.Compose(
            [
                transforms.ToTensor()
            ]
        )
        
    def setup(self, stage=None):
        
        if stage == 'fit' or stage is None:

            full_train_dataset = datasets.ImageFolder(
                root = config.TRAIN_DATA_PATH,
                transform = self.transform
            )
            
            # Store the splits on self so the dataloader methods below can
            # see them; bare locals vanish when setup() returns.
            self.train_dataset, self.val_dataset = train_test_split(
                full_train_dataset,
                test_size=0.33,
                random_state=42
            )
            
        if stage == 'test' or stage is None:

            self.test_dataset = datasets.ImageFolder(
                root=config.TEST_DATA_PATH,
                transform=self.transform
            )
            
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=True
        )
        
    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=config.BATCH_SIZE,
        )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=config.BATCH_SIZE,
        )

Model

class VGGModel(pl.LightningModule):
    
    def __init__(self):
        super().__init__()
        
        self.model = models.vgg16_bn(pretrained=True)
        self.criterion = nn.CrossEntropyLoss()
        
    def forward(self, x):
        x = self.model(x)
        return x
    
    def step(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y
        
    
    def training_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)
        
        # log train metrics
        # roc_auc_score expects (y_true, y_score); move tensors to CPU first
        acc = roc_auc_score(targets.cpu(), preds.cpu())
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        
        return {"loss": loss, "preds": preds, "targets": targets}
    
    
   
    def validation_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)

        # log val metrics
        # roc_auc_score expects (y_true, y_score); move tensors to CPU first
        acc = roc_auc_score(targets.cpu(), preds.cpu())
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets}
    
    
    def test_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)

        # log test metrics
        # roc_auc_score expects (y_true, y_score); move tensors to CPU first
        acc = roc_auc_score(targets.cpu(), preds.cpu())
        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/acc", acc, on_step=False, on_epoch=True)

        return {"loss": loss, "preds": preds, "targets": targets}
    
   
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = config.LEARNING_RATE)
        return optimizer

Training

model = VGGModel()

trainer = pl.Trainer(
    max_epochs=1,
    gpus=[0],
    precision=32,
    progress_bar_refresh_rate=20
)

trainer.fit(model, datamodule = FruitsDataModule)

Error log

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-5990df7ecb15> in <module>
      8 )
      9 
---> 10 trainer.fit(model, datamodule = FruitsDataModule)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    497 
    498         # dispath `start_training` or `start_testing` or `start_predicting`
--> 499         self.dispatch()
    500 
    501         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    544 
    545         else:
--> 546             self.accelerator.start_training(self)
    547 
    548     def train_or_test_or_predict(self):

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     71 
     72     def start_training(self, trainer):
---> 73         self.training_type_plugin.start_training(trainer)
     74 
     75     def start_testing(self, trainer):

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    112     def start_training(self, trainer: 'Trainer') -> None:
    113         # double dispatch to initiate the training loop
--> 114         self._results = trainer.run_train()
    115 
    116     def start_testing(self, trainer: 'Trainer') -> None:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    605             self.progress_bar_callback.disable()
    606 
--> 607         self.run_sanity_check(self.lightning_module)
    608 
    609         # set stage for logging

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model)
    852         # to make sure program won't crash during val
    853         if should_sanity_check:
--> 854             self.reset_val_dataloader(ref_model)
    855             self.num_sanity_val_batches = [
    856                 min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in reset_val_dataloader(self, model)
    362         has_step = is_overridden('validation_step', model)
    363         if has_loader and has_step:
--> 364             self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
    365 
    366     def reset_test_dataloader(self, model) -> None:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in _reset_eval_dataloader(self, model, mode)
    276         # always get the loaders first so we can count how many there are
    277         loader_name = f'{mode}_dataloader'
--> 278         dataloaders = self.request_dataloader(getattr(model, loader_name))
    279 
    280         if not isinstance(dataloaders, list):

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in request_dataloader(self, dataloader_fx)
    396             The dataloader
    397         """
--> 398         dataloader = dataloader_fx()
    399         dataloader = self._flatten_dl_only(dataloader)
    400 

TypeError: val_dataloader() missing 1 required positional argument: 'self'

How can I fix this?

Comments:

- The error doesn't match the line of code it points to — is there a deeper stack trace? What does dataloader_fx() do? None of the code you posted calls val_dataloader(), but the error message suggests something is trying to call it as a class method rather than an instance method.
- @Samwise I updated my question with the Training section and the full error log. I did not implement dataloader_fx(); it is part of PyTorch Lightning's internals.
- Ah — I'm not familiar with pytorch's API, but is your data module perhaps supposed to implement val_dataloader as a static method? (Put @staticmethod in front of it and remove the self parameter.)
- @Samwise Maybe you can take a look at this. It has a similar implementation on the MNIST dataset.

Answer 1:

The datamodule should be an object, not a class. So this

trainer.fit(model, datamodule=FruitsDataModule)

should be

trainer.fit(model, datamodule=FruitsDataModule())
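To see why passing the class itself produces this exact message, here is a stand-alone sketch with a hypothetical class (no PyTorch needed): looking up a method on the class gives the plain function, so calling it supplies nothing for the self parameter, whereas looking it up on an instance gives a bound method.

```python
class DM:
    def val_dataloader(self):
        return [1, 2, 3]

# Passing the class: the framework ends up calling the unbound
# function, so nothing fills the `self` parameter.
try:
    DM.val_dataloader()
except TypeError as e:
    print(e)  # ... missing 1 required positional argument: 'self'

# Passing an instance binds the method, so the call succeeds.
print(DM().val_dataloader())  # → [1, 2, 3]
```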
