PyTorch Lightning 是不是在整个时期内平均指标?
Posted
技术标签:
【中文标题】PyTorch Lightning 是不是在整个时期内平均指标?【英文标题】:Does the PyTorch Lightning average metrics over the whole epoch?PyTorch Lightning 是否在整个时期内平均指标? 【发布时间】:2021-06-05 13:08:06 【问题描述】:我正在查看PyTorch-Lightning
官方文档https://pytorch-lightning.readthedocs.io/en/0.9.0/lightning-module.html上提供的示例。
这里的损失和度量是根据具体批次计算的。但是当记录一个特定批次的准确性时,它可能相当小且不具有代表性,而是对所有时期的平均值不感兴趣。我是否理解正确,有一些代码对所有批次执行平均,并通过了 epoch?
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM
class ClassificationTask(pl.LightningModule):
def __init__(self, model):
super().__init__()
self.model = model
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
return pl.TrainResult(loss)
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
acc = FM.accuracy(y_hat, y)
result = pl.EvalResult(checkpoint_on=loss)
result.log_dict('val_acc': acc, 'val_loss': loss)
return result
def test_step(self, batch, batch_idx):
result = self.validation_step(batch, batch_idx)
result.rename_keys('val_acc': 'test_acc', 'val_loss': 'test_loss')
return result
def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(), lr=0.02)
【问题讨论】:
好问题 - 我来找同样的东西。我怀疑它是每批次的。 【参考方案1】:如果您想对整个时期的指标进行平均,您需要告诉 LightningModule
您已经子类化了这样做。有几种不同的方法可以做到这一点,例如:
-
使用
on_epoch=True
调用result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
as shown in the docs,以便在整个时期内平均训练损失。即:
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
result = pl.TrainResult(loss)
result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return result
-
或者,您可以在
LightningModule
本身上调用 log
方法:self.log("train_loss", loss, on_epoch=True, sync_dist=True)
(可以选择传递 sync_dist=True
以跨加速器减少)。
您需要在 validation_step
中执行类似的操作来获取聚合的 val-set 指标或在 validation_epoch_end
方法中自己实现聚合。
【讨论】:
以上是关于PyTorch Lightning 是不是在整个时期内平均指标?的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch-Lightning 是不是具有多处理(或 Joblib)模块?
如何从Pytorch 到 Pytorch Lightning | 简要介绍