Python:TypeError:val_dataloader()缺少1个必需的位置参数:'self'
Posted
技术标签:
【中文标题】Python:TypeError:val_dataloader()缺少1个必需的位置参数:\'self\'【英文标题】:Python: TypeError: val_dataloader() missing 1 required positional argument: 'self'Python:TypeError:val_dataloader()缺少1个必需的位置参数:'self' 【发布时间】:2021-09-15 04:45:10 【问题描述】:我正在使用 PyTorch Lightning 进行图像分类任务。但是我在实现它时得到了TypeError
。我已经创建了数据模块和模型,如 PyTorch Lightning 示例中所示。我使用的模型是VGG16
,带有Batch Normalization。
在FruitsDataModule
中,我只得到val_dataloader
的错误,而不是train_dataloader
的错误,这令人困惑,因为这两个函数都在用不同的数据做完全相同的事情。
相关代码如下所示。
数据模块
class FruitsDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.transform = transforms.Compose(
[
transforms.ToTensor()
]
)
def setup(self, stage=None):
if stage == 'fit' or stage is None:
full_train_dataset = datasets.ImageFolder(
root = config.TRAIN_DATA_PATH,
transform = self.transform
)
train_dataset, val_dataset = train_test_split(
full_train_dataset,
test_size=0.33,
random_state = 42
)
if stage == 'test' or stage is None:
test_dataset = datasets.ImageFolder(
root = config.TEST_DATA_PATH,
transform = self.transform
)
def train_dataloader(self):
return DataLoader(
train_dataset,
batch_size = config.BATCH_SIZE,
shuffle = True
)
def val_dataloader(self):
return DataLoader(
val_dataset,
batch_size = config.BATCH_SIZE,
)
def test_dataloader(self):
return DataLoader(
test_dataset,
batch_size = config.BATCH_SIZE,
)
型号
class VGGModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = models.vgg16_bn(pretrained=True)
self.criterion = nn.CrossEntropyLoss()
def forward(self, x):
x = self.model(x)
return x
def step(self, batch):
x, y = batch
logits = self.forward(x)
loss = self.criterion(logits, y)
preds = torch.argmax(logits, dim=1)
return loss, preds, y
def training_step(self, batch, batch_idx):
loss, preds, targets = self.step(batch)
# log train metrics
acc = roc_auc_score(preds, targets)
self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
return "loss": loss, "preds": preds, "targets": targets
def validation_step(self, batch, batch_idx):
loss, preds, targets = self.step(batch)
# log val metrics
acc = roc_auc_score(preds, targets)
self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
return "loss": loss, "preds": preds, "targets": targets
def test_step(self, batch, batch_idx):
loss, preds, targets = self.step(batch)
# log test metrics
acc = roc_auc_score(preds, targets)
self.log("test/loss", loss, on_step=False, on_epoch=True)
self.log("test/acc", acc, on_step=False, on_epoch=True)
return "loss": loss, "preds": preds, "targets": targets
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr = config.LEARNING_RATE)
return optimizer
培训
model = VGGModel()
trainer = pl.Trainer(
max_epochs=1,
gpus=[0],
precision=32,
progress_bar_refresh_rate=20
)
trainer.fit(model, datamodule = FruitsDataModule)
错误日志
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-6-5990df7ecb15> in <module>
8 )
9
---> 10 trainer.fit(model, datamodule = FruitsDataModule)
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
497
498 # dispath `start_training` or `start_testing` or `start_predicting`
--> 499 self.dispatch()
500
501 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
544
545 else:
--> 546 self.accelerator.start_training(self)
547
548 def train_or_test_or_predict(self):
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
71
72 def start_training(self, trainer):
---> 73 self.training_type_plugin.start_training(trainer)
74
75 def start_testing(self, trainer):
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
112 def start_training(self, trainer: 'Trainer') -> None:
113 # double dispatch to initiate the training loop
--> 114 self._results = trainer.run_train()
115
116 def start_testing(self, trainer: 'Trainer') -> None:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
605 self.progress_bar_callback.disable()
606
--> 607 self.run_sanity_check(self.lightning_module)
608
609 # set stage for logging
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model)
852 # to make sure program won't crash during val
853 if should_sanity_check:
--> 854 self.reset_val_dataloader(ref_model)
855 self.num_sanity_val_batches = [
856 min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in reset_val_dataloader(self, model)
362 has_step = is_overridden('validation_step', model)
363 if has_loader and has_step:
--> 364 self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
365
366 def reset_test_dataloader(self, model) -> None:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in _reset_eval_dataloader(self, model, mode)
276 # always get the loaders first so we can count how many there are
277 loader_name = f'mode_dataloader'
--> 278 dataloaders = self.request_dataloader(getattr(model, loader_name))
279
280 if not isinstance(dataloaders, list):
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py in request_dataloader(self, dataloader_fx)
396 The dataloader
397 """
--> 398 dataloader = dataloader_fx()
399 dataloader = self._flatten_dl_only(dataloader)
400
TypeError: val_dataloader() missing 1 required positional argument: 'self'
如何解决这个问题?
【问题讨论】:
错误与它指向的代码行不匹配——是否有更深的堆栈跟踪?dataloader_fx()
是做什么的?您发布的代码都没有调用val_dataloader()
,但错误消息表明某些东西试图将其作为类方法而不是实例方法来调用。
@Samwise 我通过添加 Training
部分更新了我的问题,并显示了整个错误日志。 dataloader_fx()
我没有实现,是pytorch闪电内部实现
啊——我不熟悉 pytorch 的 API,但是您的数据模块是否应该将 val_dataloader
实现为静态方法? (把@staticmethod
放在前面,去掉self
参数。)
@Samwise 也许你可以看看this。它在 mnist 数据集上有类似的实现。
【参考方案1】:
datamodule
应该是对象,而不是类。因此,这个
trainer.fit(model, datamodule=FruitsDataModule)
应该是
trainer.fit(model, datamodule=FruitsDataModule())
【讨论】:
以上是关于Python:TypeError:val_dataloader()缺少1个必需的位置参数:'self'的主要内容,如果未能解决你的问题,请参考以下文章
python fbprophet错误,TypeError:'module'对象不可调用
TypeError:“NoneType”对象在 Python 中不可迭代