Pytorch实现Top1准确率和Top5准确率

Posted yqpy

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch实现Top1准确率和Top5准确率相关的知识,希望对你有一定的参考价值。

之前一直不清楚Top1和Top5是什么,其实搞清楚了很简单,就是两种衡量指标,其中,Top1就是普通的Accuracy,Top5比Top1衡量标准更“严格”,

具体来讲,比如一共需要分10类,每次分类器的输出结果都是10个相加为1的概率值,Top1就是这十个值中最大的那个概率值对应的分类恰好正确的频率,而Top5则是在十个概率值中从大到小排序出前五个,然后看看这前五个分类中是否存在那个正确分类,再计算频率。Pytorch实现如下:

def evaluteTop1(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
        #correct += torch.eq(pred, y).sum().item()
    return correct / total

def evaluteTop5(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x,y = x.to(device),y.to(device)
        with torch.no_grad():
            logits = model(x)
            maxk = max((1,5))
        y_resize = y.view(-1,1) _, pred
= logits.topk(maxk, 1, True, True) correct += torch.eq(pred, y_resize).sum().float().item() return correct / total

注意:y_resize = y.view(-1,1)是非常关键的一步,在correct的运算中,关键就是要pred和y_resize维度匹配,而原来的y是[128],128是batch大小;

pred的维度则是[128,10],假设这里是CIFAR10十分类;因此必须把y转化成[128,1]这种维度,但是不能直接是y.view(128,1),因为遍历整个数据集的时候,

最后一个batch大小并不是128,所以view()里面第一个size就设为-1未知,而确保第二个size是1就行

topk函数的具体用法参见https://blog.csdn.net/u014264373/article/details/86525621

以上是关于Pytorch实现Top1准确率和Top5准确率的主要内容,如果未能解决你的问题,请参考以下文章

评估和计算 Top-N 准确度:Top 1 和 Top 5

使用 CNN 和 pytorch 计算每个类的准确率

为啥损失减少而准确率却没有增加? PyTorch

经典的卷积神经网络及其Pytorch代码实现

详解pytorch中的交叉熵损失函数nn.BCELoss()nn.BCELossWithLogits(),二分类任务如何定义损失函数,如何计算准确率如何预测

项目1:pytorch实现文本情感分析详细教程-准确度高达82%-98%