PyTorch-lightning 模型在第一个 epoch 后内存不足

Posted

技术标签:

【中文标题】PyTorch-lightning 模型在第一个 epoch 后内存不足【英文标题】:PyTorch-lightning models running out of Memory after 1st epoch 【发布时间】:2021-09-22 02:18:06 【问题描述】:

我在 PyTorch 上看到了一个 Kaggle 内核,并使用相同的 img_size、batch_size 等运行它,并创建了另一个具有完全相同值的 PyTorch-lightning 内核,但我的闪电模型在大约 1.5 个 epoch 后内存不足(每个 epoch 包含 8750步骤)在第一折,而原生 PyTorch 模型运行整个 5 折。有没有办法改进代码或释放内存?我本可以尝试删除模型或进行一些垃圾收集,但如果它甚至没有完成第一次折叠,我就无法删除模型和东西。

def run_fold(fold):
    
    df_train = train[train['fold'] != fold]
    df_valid = train[train['fold'] == fold]
    
    train_dataset = G2NetDataset(df_train, get_train_aug())
    valid_dataset = G2NetDataset(df_valid, get_test_aug())
    
    train_dl = DataLoader(train_dataset,
                          batch_size = config.batch_size,
                          num_workers = config.num_workers,
                          shuffle = True,
                          drop_last = True,
                          pin_memory = True)
    
    valid_dl = DataLoader(valid_dataset,
                         batch_size = config.batch_size,
                         num_workers = config.num_workers,
                         shuffle = False,
                         drop_last = False,
                         pin_memory = True)
    
    
    model = Classifier()
    logger = pl.loggers.WandbLogger(project='G2Net', name=f'fold: fold')
    
    trainer = pl.Trainer(gpus = 1, 
                         max_epochs = config.epochs,
                         fast_dev_run = config.debug,
                         logger = logger,
                         log_every_n_steps=10)
    
    trainer.fit(model, train_dl, valid_dl)
    result = trainer.test(test_dataloaders = valid_dl)
    wandb.run.finish() 
    return result

def main():   
    if config.train:
        results = []
        for fold in range(config.n_fold):
            result = run_fold(fold)
            results.append(result)      
    return results

results = main()

【问题讨论】:

也许要尝试的一件事是在计算指标时将 .detach() 用于当前循环中不需要的任何张量。这样,您就不会无缘无故地存储张量及其整个图表。你试过这个吗? 但我使用的是 Pytorch-Lightning,因此无法在任何给定点分离,因为闪电管理后台中的所有内容。 嗨,你能附上你的火车功能的sn-p吗?这样调试起来会更容易。 【参考方案1】:

如果不查看您的模型类,我不能说太多,但我遇到的几个可能的问题是日志记录的度量和损失评估。 例如,像

pl.metrics.Accuracy(compute_on_step=False)

需要和显式调用 .compute()

def training_epoch_end(self, outputs):
    loss = sum([out['loss'] for out in outputs])/len(outputs)
    self.log_dict('train_loss' : loss.detach(), 
               'train_accuracy' : self.train_metric.compute())

在纪元结束时。

【讨论】:

以上是关于PyTorch-lightning 模型在第一个 epoch 后内存不足的主要内容,如果未能解决你的问题,请参考以下文章

使用 pytorch-lightning 实现 Network in Network CNN 模型

使用 pytorch-lightning 进行简单预测的示例

如何禁用 PyTorch-Lightning 记录器的日志记录?

Pytorch-Lightning 是不是具有多处理(或 Joblib)模块?

pytorch-lightning 中的正态分布采样

pytorch-lightning入门—— 初了解