Unable to load model from checkpoint in Pytorch-Lightning

Posted: 2021-01-15 19:06:12

I am using a U-Net in Pytorch Lightning. I am able to train the model successfully, but after training, when I try to load the model from a checkpoint I get this error:

Complete traceback:

Traceback (most recent call last):
  File "src/train.py", line 269, in <module>
    main(sys.argv[1:])
  File "src/train.py", line 263, in main
    model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 153, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, strict=strict, **kwargs)
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 190, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
  File "src/train.py", line 162, in __init__
    self.inc = double_conv(self.n_channels, 64)
  File "src/train.py", line 122, in double_conv
    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 406, in __init__
    super(Conv2d, self).__init__(
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 50, in __init__
    if in_channels % groups != 0:
TypeError: unsupported operand type(s) for %: 'dict' and 'int'

I tried going through the GitHub issues and forums, but I couldn't figure out what the problem is. Please help.

Here is the code for my model and the checkpoint-loading step.

Model:

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger


class Unet(pl.LightningModule):
    def __init__(self, n_channels, n_classes=5):
        super(Unet, self).__init__()
        # self.hparams = hparams

        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = True
        self.logger = WandbLogger(name="Adam", project="pytorchlightning")

        def double_conv(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        def down(in_channels, out_channels):
            return nn.Sequential(
                nn.MaxPool2d(2), double_conv(in_channels, out_channels)
            )

        class up(nn.Module):
            def __init__(self, in_channels, out_channels, bilinear=False):
                super().__init__()

                if bilinear:
                    self.up = nn.Upsample(
                        scale_factor=2, mode="bilinear", align_corners=True
                    )
                else:
                    self.up = nn.ConvTranspose2d(
                        in_channels // 2, in_channels // 2, kernel_size=2, stride=2
                    )

                self.conv = double_conv(in_channels, out_channels)

            def forward(self, x1, x2):
                x1 = self.up(x1)
                # [?, C, H, W]
                diffY = x2.size()[2] - x1.size()[2]
                diffX = x2.size()[3] - x1.size()[3]

                x1 = F.pad(
                    x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
                )
                x = torch.cat([x2, x1], dim=1)
                return self.conv(x)

        self.inc = double_conv(self.n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.out = nn.Conv2d(64, self.n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)

        x2 = self.down1(x1)

        x3 = self.down2(x2)

        x4 = self.down3(x3)

        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        return self.out(x)

    def training_step(self, batch, batch_nb):
        x, y = batch

        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)

        # wandb_logger.log_metrics({"loss": loss})
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.logger.log_metrics({"train_loss": avg_train_loss})
        return {"average_loss": avg_train_loss}

    def test_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)
        return {"test_loss": loss, "pred": y_hat}

    def test_end(self, outputs):

        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()

        return {"avg_test_loss": avg_loss}

    def MSE(self, logits, labels):

        return torch.mean((logits - labels) ** 2)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1, weight_decay=1e-8)

Main function:

from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader


def main(expconfig):
    # Define checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        filepath="/home/africa_wikilimo/data/model_checkpoint/",
        save_top_k=1,
        verbose=True,
        monitor="loss",
        mode="min",
        prefix="",
    )

    # Initialise datasets
    print("Initializing Climate Dataset....")
    clima_train = Clima_Dataset(expconfig[0])

    # Initialise dataloaders
    print("Initializing train_loader....")
    train_dataloader = DataLoader(clima_train, batch_size=2, num_workers=4)

    # Initialise model and trainer
    print("Initializing model...")
    model = Unet(n_channels=9, n_classes=5)
    print("Initializing Trainer....")
    if torch.cuda.is_available():

        model.cuda()

        trainer = pl.Trainer(
            max_epochs=1,
            gpus=1,
            checkpoint_callback=checkpoint_callback,
            early_stop_callback=None,
        )
    else:

        trainer = pl.Trainer(max_epochs=1, checkpoint_callback=checkpoint_callback)
    
    trainer.fit(model, train_dataloader=train_dataloader)
    print(checkpoint_callback.best_model_path)
    model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)

Answer 1:

Cause

This happens because your model is unable to load its hyperparameters (n_channels, n_classes=5) from the checkpoint: you never saved them explicitly, so load_from_checkpoint ends up passing the raw checkpoint dictionary where n_channels is expected, which produces the TypeError above.

Fix

You can resolve it by calling self.save_hyperparameters('n_channels', 'n_classes') in the __init__ method of your Unet class. For more details on using this method, see the PyTorch Lightning hyperparams-docs. With save_hyperparameters, the selected arguments are saved to hparams.yaml together with the checkpoint, so load_from_checkpoint can reconstruct the model, as in the sketch below.
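
A minimal sketch of the fix, assuming the same Unet constructor signature as in the question (only the changed lines are shown; the rest of the class is unchanged):

import pytorch_lightning as pl

class Unet(pl.LightningModule):
    def __init__(self, n_channels, n_classes=5):
        super().__init__()
        # Persist the constructor arguments to hparams.yaml inside the
        # checkpoint, so load_from_checkpoint can rebuild the model without
        # being handed n_channels and n_classes again.
        self.save_hyperparameters("n_channels", "n_classes")

        self.n_channels = n_channels
        self.n_classes = n_classes
        # ... rest of __init__ unchanged ...

# After retraining with the change above, loading works as before:
# model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)

For a checkpoint that was already saved without hyperparameters, load_from_checkpoint should also accept keyword arguments that are forwarded to __init__ (for example Unet.load_from_checkpoint(path, n_channels=9, n_classes=5)), which avoids retraining.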

Thanks to @Adrian Wälchli (awaelchli) from the PyTorch Lightning core contributors team, who suggested this fix when I ran into the same issue.

Comments:

The link is dead.
