Pytorch 闪电:“CIFAR10DataModule”对象没有属性“train_loader”
Posted
技术标签:
【中文标题】Pytorch 闪电:“CIFAR10DataModule”对象没有属性“train_loader”【英文标题】:Pytorch lightning: 'CIFAR10DataModule' object has no attribute 'train_loader' 【发布时间】:2021-10-28 08:40:44 【问题描述】:你能告诉我为什么我无法导入 CUFAR10DataModule() 吗?
一开始,我在 GoogleColab 上运行代码,
from pl_bolts.datamodules import CIFAR10DataModule
dm = CIFAR10DataModule()
然后,执行代码进行确认
from torch.optim import Adam
optimizer = Adam(finetune_layer.parameters(), lr=1e-4)
for epoch in range(10):
for batch in dm.train_loader:
x, y = batch
with torch.no_grad():
features = backbone(x)
preds = finetune_layer(features)
loss = cross_entropy(preds, y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(loss.item())
但是,运行代码后返回消息AttributeError: 'CIFAR10DataModule' object has no attribute 'train_loader'
。
当运行代码确认dm
时,
for batch in dm.train_dataloader:
x, y = batch
print(x.shape, y.shape)
break
错误为TypeError: 'method' object is not iterable
。
代码看起来和一个例子一样,但是不知道为什么会产生这样的错误?
【问题讨论】:
【参考方案1】:你的代码有两个问题:
首先,获取底层 PyTorch 数据加载器的方式是 dm.train_dataloader()
而不是 dm.train_loader
。它是一个函数,而不是一个属性。
for batch in dm.train_dataloader():
x, y = batch
...
其次,由于您尝试使用LightningDataModule
而不使用Trainer
,因此您需要手动调用
dm.prepare_data()
dm.setup()
.. 以便通过.train_dataloader()
使用数据加载器。
【讨论】:
以上是关于Pytorch 闪电:“CIFAR10DataModule”对象没有属性“train_loader”的主要内容,如果未能解决你的问题,请参考以下文章