如何在忽略类中使用 pytorch 闪电精度?

Posted

技术标签:

【中文标题】如何在忽略类中使用 pytorch 闪电精度?【英文标题】:How to use pytorch lightning Accuracy with ignore class? 【发布时间】:2021-07-04 05:05:13 【问题描述】:

我有一些使用 CrossEntropyLoss 和忽略类的训练管道。

模型输出 log_probs 形状 (150, 3) - 表示 3 个可能的类,每批 150 个。

label_batch 的形状为 150torch.max(label_batch) == tensor(3, device='cuda:0'),这意味着有一个标记为 3 的额外类,即忽略类。

损失处理得很好:

self._criterion = nn.CrossEntropyLoss(
    reduction='mean',
    ignore_index=3
)

但是准确度指标认为3 类是有效的并且给出了非常错误的结果:

self.train_acc = pl.metrics.Accuracy()

由于3 标签导致self.train_acc.update(log_probs, label_batch) 的错误结果应被忽略。


如何正确使用 pl.metrics.Accuracy() 和忽略类?

【问题讨论】:

【参考方案1】:

复制github论坛https://github.com/PyTorchLightning/pytorch-lightning/discussions/6890讨论帖的回复


准确度指标目前不支持它,但我们有一个开放的 PR 用于实现该精确功能 PyTorchLightning/metrics#155

目前您可以改为计算混淆矩阵,然后基于此忽略一些类别(请记住,真正的正分类/正确分类位于混淆矩阵的对角线上):

ignore_index = 3
metric = ConfusionMatrix(num_classes=3)
confmat = metric(preds, target)
confmat = confmat[:2,:2] # remove last column and row corresponding to class 3
acc = confmat.trace() / confmat.sum()

【讨论】:

谢谢!这是一个很好的答案,但让我想知道为什么所有指标都不是基于ConfusionMatrix?只有更新应该去矩阵,然后它应该有一些属性,如accuracyf1mcc。为什么它们不同? 另外我最终做的(在阅读本文之前)是ignored_inds = label_batch == self._ignore_class_numacc.update(log_probs[~ignored_inds, :], label_batch[~ignored_inds])。有意义吗?

以上是关于如何在忽略类中使用 pytorch 闪电精度?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 pytorch 闪电中按每个时期从记录器中提取损失和准确性?

如何使用 Pytorch 中的截断反向传播(闪电)在很长的序列上运行 LSTM?

pytorch闪电模型的输出预测

如何使用 PyTorch 在语义分割中获得前 k 个精度?

计算精度问题

使用 pytorch 闪电的不同测试结果