Pytorch实现Top1准确率和Top5准确率
Posted yqpy
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch实现Top1准确率和Top5准确率相关的知识,希望对你有一定的参考价值。
之前一直不清楚Top1和Top5是什么,其实搞清楚了很简单,就是两种衡量指标,其中,Top1就是普通的Accuracy,Top5比Top1衡量标准更“严格”,
具体来讲,比如一共需要分10类,每次分类器的输出结果都是10个相加为1的概率值,Top1就是这十个值中最大的那个概率值对应的分类恰好正确的频率,而Top5则是在十个概率值中从大到小排序出前五个,然后看看这前五个分类中是否存在那个正确分类,再计算频率。Pytorch实现如下:
def evaluteTop1(model, loader): model.eval() correct = 0 total = len(loader.dataset) for x,y in loader: x,y = x.to(device), y.to(device) with torch.no_grad(): logits = model(x) pred = logits.argmax(dim=1) correct += torch.eq(pred, y).sum().float().item() #correct += torch.eq(pred, y).sum().item() return correct / total def evaluteTop5(model, loader): model.eval() correct = 0 total = len(loader.dataset) for x, y in loader: x,y = x.to(device),y.to(device) with torch.no_grad(): logits = model(x) maxk = max((1,5))
y_resize = y.view(-1,1) _, pred = logits.topk(maxk, 1, True, True) correct += torch.eq(pred, y_resize).sum().float().item() return correct / total
注意:y_resize = y.view(-1,1)是非常关键的一步,在correct的运算中,关键就是要pred和y_resize维度匹配,而原来的y是[128],128是batch大小;
pred的维度则是[128,10],假设这里是CIFAR10十分类;因此必须把y转化成[128,1]这种维度,但是不能直接是y.view(128,1),因为遍历整个数据集的时候,
最后一个batch大小并不是128,所以view()里面第一个size就设为-1未知,而确保第二个size是1就行
topk函数的具体用法参见https://blog.csdn.net/u014264373/article/details/86525621
以上是关于Pytorch实现Top1准确率和Top5准确率的主要内容,如果未能解决你的问题,请参考以下文章
详解pytorch中的交叉熵损失函数nn.BCELoss()nn.BCELossWithLogits(),二分类任务如何定义损失函数,如何计算准确率如何预测