How to extract loss and accuracy from logger by each epoch in pytorch lightning?

【Title】:How to extract loss and accuracy from logger by each epoch in pytorch lightning? 【Posted】:2021-11-15 11:48:29 【Question】:

I want to extract all the data to make plots myself, instead of using TensorBoard. My understanding is that, since TensorBoard draws the line charts, all the loss and accuracy logs are stored in the specified directory.

%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

However, I would like to know how to extract all the logs from the logger in PyTorch Lightning. Below is a code sample of the training part.

#model
ssl_classifier = SSLImageClassifier(lr=lr)

#train
logger = pl.loggers.TensorBoardLogger(name=f'ssl-lr-num_epoch', save_dir='lightning_logs')

trainer = pl.Trainer(progress_bar_refresh_rate=20,
                            gpus=1,
                            max_epochs = max_epoch,
                            logger = logger,
                            )

trainer.fit(ssl_classifier, train_loader, val_loader)

I have confirmed that trainer.logger.log_dir returns the directory where the logs seem to be saved, and that trainer.logger.log_metrics returns <bound method TensorBoardLogger.log_metrics of <pytorch_lightning.loggers.tensorboard.TensorBoardLogger object at 0x7efcb89a3e50>>.

trainer.logged_metrics only returns the logs of the last epoch, such as:

'epoch': 19,
 'train_acc': tensor(1.),
 'train_loss': tensor(0.1038),
 'val_acc': 0.6499999761581421,
 'val_loss': 1.2171183824539185

Does anyone know how to solve this?

【Comments】:

【Answer 1】:

Lightning does not store all the logs by itself. All it does is stream them to the logger instance, and the logger decides what to do with them.
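As an aside, before the callback approach: since TensorBoardLogger writes ordinary TensorBoard event files to disk, one option is to read the metrics back from those files. A rough sketch, assuming the `tensorboard` package is installed; `read_scalars` is just an illustrative helper name:

```python
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def read_scalars(logdir, tag):
    """Return (steps, values) for one scalar tag from a TensorBoard log dir."""
    acc = EventAccumulator(logdir)
    acc.Reload()               # parse the event files found under logdir
    events = acc.Scalars(tag)  # list of ScalarEvent(wall_time, step, value)
    return [e.step for e in events], [e.value for e in events]
```

You would call it as e.g. `read_scalars(trainer.logger.log_dir, 'train_loss')`, where the tag names are whatever you passed to `self.log(...)`.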

The best way to retrieve all the logged metrics is with a custom callback:

from pytorch_lightning.callbacks import Callback

class MetricTracker(Callback):

  def __init__(self):
    self.collection = []

  def on_validation_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx):
    vacc = outputs['val_acc'] # you can access them here
    self.collection.append(vacc) # track them

  def on_validation_epoch_end(self, trainer, module):
    elogs = trainer.logged_metrics # access it here
    self.collection.append(elogs)
    # do whatever is needed

Then you can access everything that was logged from the callback instance:

cb = MetricTracker()
trainer = Trainer(callbacks=[cb])

cb.collection # do your plotting and stuff
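To turn the tracked records into something plottable, a small helper can flatten them into per-metric lists of plain floats. This is only a sketch: `collect_series` is a made-up name, and it assumes each record is either a bare tensor/number (from the batch hook) or a dict of them (from the epoch hook):

```python
def collect_series(records):
    """Flatten a list of tracked records into {metric_name: [float, ...]}."""
    series = {}
    for record in records:
        # batch hooks append bare values; epoch hooks append metric dicts
        if not isinstance(record, dict):
            record = {"val_acc": record}
        for name, value in record.items():
            # float() also unwraps 0-dim torch tensors like tensor(0.1038)
            series.setdefault(name, []).append(float(value))
    return series

logs = [{"train_loss": 0.9, "epoch": 0}, {"train_loss": 0.4, "epoch": 1}]
collect_series(logs)  # {'train_loss': [0.9, 0.4], 'epoch': [0.0, 1.0]}
```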

【Comments】:

【Answer 2】:

The accepted answer is not fundamentally wrong, but it does not follow the official (current) PyTorch Lightning guide.

As suggested here: https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#make-a-custom-logger

it is recommended to write a class like this:

from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment


class MyLogger(LightningLoggerBase):
    @property
    def name(self):
        return "MyLogger"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @property
    def version(self):
        # Return the experiment version, int or str.
        return "0.1"

    @rank_zero_only
    def log_hyperparams(self, params):
        # params is an argparse.Namespace
        # your code to record hyperparameters goes here
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        # your code to record metrics goes here
        pass

    @rank_zero_only
    def save(self):
        # Optional. Any code necessary to save logger data goes here
        # If you implement this, remember to call `super().save()`
        # at the start of the method (important for aggregation of metrics)
        super().save()

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training
        # finishes goes here
        pass

By looking at the internals of the LightningLoggerBase class, you can see which other functions are suggested to be overridden.

Here is my minimalist logger. It is not highly optimized, but it is a good first shot. I will edit it if I improve it.

import collections

from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only

class History_dict(LightningLoggerBase):
    def __init__(self):
        super().__init__()

        # a defaultdict simply creates any item you try to access,
        # so no copying is necessary here
        self.history = collections.defaultdict(list)

    @property
    def name(self):
        return "Logger_custom_plot"

    @property
    def version(self):
        return "1.0"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        for metric_name, metric_value in metrics.items():
            if metric_name != 'epoch':
                self.history[metric_name].append(metric_value)
            else:
                # the 'epoch' key arrives once per logged metric group;
                # avoid appending the same epoch value multiple times
                if (not len(self.history['epoch']) or  # len == 0
                        self.history['epoch'][-1] != metric_value):
                    self.history['epoch'].append(metric_value)
        return

    @rank_zero_only
    def log_hyperparams(self, params):
        pass
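The epoch bookkeeping in log_metrics can be exercised on its own. Here is a stand-alone sketch of the same dedup logic in plain Python (no Lightning needed; `HistoryDict` is just a stand-in name), showing that repeated 'epoch' values are collapsed while every metric value is kept:

```python
import collections

class HistoryDict:
    """Plain-Python stand-in for the logger's log_metrics bookkeeping."""
    def __init__(self):
        self.history = collections.defaultdict(list)

    def log_metrics(self, metrics, step=None):
        for name, value in metrics.items():
            if name != "epoch":
                self.history[name].append(value)
            # record an epoch value only once, even if several metric
            # groups are logged within the same epoch
            elif not self.history["epoch"] or self.history["epoch"][-1] != value:
                self.history["epoch"].append(value)

h = HistoryDict()
h.log_metrics({"train_loss": 0.9, "epoch": 0})
h.log_metrics({"val_loss": 1.2, "epoch": 0})   # epoch 0 already recorded
h.log_metrics({"train_loss": 0.4, "epoch": 1})
# h.history["epoch"] is [0, 1]; each loss list keeps all its values
```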

【Comments】:
