如何在忽略类中使用 pytorch 闪电精度?
Posted
技术标签:
【中文标题】如何在忽略类中使用 pytorch 闪电精度?【英文标题】:How to use pytorch lightning Accuracy with ignore class? 【发布时间】:2021-07-04 05:05:13 【问题描述】:我有一些使用 CrossEntropyLoss
和忽略类的训练管道。
模型输出 log_probs
形状 (150, 3)
- 表示 3 个可能的类,每批 150 个。
label_batch
的形状为 150
,torch.max(label_batch)
== tensor(3, device='cuda:0')
,这意味着有一个标记为 3
的额外类,即忽略类。
损失处理得很好:
self._criterion = nn.CrossEntropyLoss(
reduction='mean',
ignore_index=3
)
但是准确度指标认为3
类是有效的并且给出了非常错误的结果:
self.train_acc = pl.metrics.Accuracy()
由于3
标签导致self.train_acc.update(log_probs, label_batch)
的错误结果应被忽略。
如何正确使用 pl.metrics.Accuracy()
和忽略类?
【问题讨论】:
【参考方案1】:复制github论坛https://github.com/PyTorchLightning/pytorch-lightning/discussions/6890讨论帖的回复
准确度指标目前不支持它,但我们有一个开放的 PR 用于实现该精确功能 PyTorchLightning/metrics#155
目前您可以改为计算混淆矩阵,然后基于此忽略一些类别(请记住,真正的正分类/正确分类位于混淆矩阵的对角线上):
ignore_index = 3
metric = ConfusionMatrix(num_classes=3)
confmat = metric(preds, target)
confmat = confmat[:2,:2] # remove last column and row corresponding to class 3
acc = confmat.trace() / confmat.sum()
【讨论】:
谢谢!这是一个很好的答案,但让我想知道为什么所有指标都不是基于ConfusionMatrix
?只有更新应该去矩阵,然后它应该有一些属性,如accuracy
、f1
、mcc
。为什么它们不同?
另外我最终做的(在阅读本文之前)是ignored_inds = label_batch == self._ignore_class_num
和acc.update(log_probs[~ignored_inds, :], label_batch[~ignored_inds])
。有意义吗?以上是关于如何在忽略类中使用 pytorch 闪电精度?的主要内容,如果未能解决你的问题,请参考以下文章
如何在 pytorch 闪电中按每个时期从记录器中提取损失和准确性?