Pytorch 闪电:“CIFAR10DataModule”对象没有属性“train_loader”

Posted

技术标签:

【中文标题】Pytorch 闪电:“CIFAR10DataModule”对象没有属性“train_loader”【英文标题】:Pytorch lightning: 'CIFAR10DataModule' object has no attribute 'train_loader' 【发布时间】:2021-10-28 08:40:44 【问题描述】:

你能告诉我为什么我无法导入 CUFAR10DataModule() 吗?

一开始,我在 GoogleColab 上运行代码,

from pl_bolts.datamodules import CIFAR10DataModule
dm = CIFAR10DataModule()

然后,执行代码进行确认

from torch.optim import Adam
optimizer = Adam(finetune_layer.parameters(), lr=1e-4)

for epoch in range(10):
  for batch in dm.train_loader:
    x, y = batch
    with torch.no_grad():
      features = backbone(x)

    preds = finetune_layer(features)
    loss = cross_entropy(preds, y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(loss.item())

但是,运行代码后返回消息AttributeError: 'CIFAR10DataModule' object has no attribute 'train_loader'

当运行代码确认dm时,

for batch in dm.train_dataloader:
  x, y = batch
  print(x.shape, y.shape)
  break

错误为TypeError: 'method' object is not iterable

代码看起来和一个例子一样,但是不知道为什么会产生这样的错误?

【问题讨论】:

【参考方案1】:

你的代码有两个问题:

首先,获取底层 PyTorch 数据加载器的方式是 dm.train_dataloader() 而不是 dm.train_loader。它是一个函数,而不是一个属性

for batch in dm.train_dataloader():
    x, y = batch
    ...

其次,由于您尝试使用LightningDataModule 而不使用Trainer,因此您需要手动调用

dm.prepare_data()
dm.setup()

.. 以便通过.train_dataloader() 使用数据加载器。

【讨论】:

以上是关于Pytorch 闪电:“CIFAR10DataModule”对象没有属性“train_loader”的主要内容,如果未能解决你的问题,请参考以下文章

用 pytorch 闪电组织张量板图

权重和偏差扫描无法使用 pytorch 闪电导入模块

使用 pytorch 闪电的不同测试结果

通过模型检查点时 Pytorch 闪电出错

训练步骤未在 pytorch 闪电中执行

如何将 pytorch 闪电分析器与 tensorboard 集成?