如何使用 PyTorch 在语义分割中获得前 k 个精度?
Posted
技术标签:
【中文标题】如何使用 PyTorch 在语义分割中获得前 k 个精度?【英文标题】:How to get top k accuracy in semantic segmentation using PyTorch? 【发布时间】:2020-04-15 21:40:42 【问题描述】:如何计算语义分割中的 top k 准确率?在分类中,我们可以将topk准确率计算为:
correct = output.eq(gt.view(1, -1).expand_as(output))
【问题讨论】:
对于正常分类,您可以检查:discuss.pytorch.org/t/imagenet-example-accuracy-calculation/… 这有帮助吗?细分有什么不同? 【参考方案1】:您正在寻找 torch.topk
函数,该函数可计算维度上的前 k 个值。torch.topk
的第二个输出是“arg top k”:前值的 k 个索引。
这是在语义分割上下文中的使用方法:
假设您有形状为b
-h
-w
(dtype=torch.int64
) 的基本事实预测张量y
。
您的模型预测形状 b
-c
-h
-w
的每像素类 logits
,c
是类的数量(包括“背景”)。这些 logits 是 之前 softmax
函数将它们转换为类概率的“原始”预测。
由于我们只查看顶部的k
,因此预测是“原始”还是“概率”并不重要。
# compute the top k predicted classes, per pixel:
_, tk = torch.topk(logits, k, dim=1)
# you now have k predictions per pixel, and you want that one of them will match the true labels y:
correct_pixels = torch.eq(y[:, None, ...], tk).any(dim=1)
# take the mean of correct_pixels to get the overall average top-k accuracy:
top_k_acc = correct_pixels.mean()
请注意,此方法不考虑“忽略”像素。这可以通过对上述代码稍作修改来完成:
valid = y != ignore_index
top_k_acc = correct_pixels[valid].mean()
【讨论】:
【参考方案2】:假设您的输出是根据您的课程列表排序的一系列分数labels
:
import torch
scores, indices = torch.topk(output, k)
correct = labels[indices]
【讨论】:
以上是关于如何使用 PyTorch 在语义分割中获得前 k 个精度?的主要内容,如果未能解决你的问题,请参考以下文章
TorchSeg—基于PyTorch的快速模块化语义分割开源库