PyTorchtorch.topk() 函数详解
Posted ZSYL
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorchtorch.topk() 函数详解相关的知识,希望对你有一定的参考价值。
1. 作用
取一个tensor的topk元素(降序后的前k个大小的元素值及索引)
2. 使用方法
dim=0
表示按照列求 topndim=1
表示按照行求 topn- 默认情况下,dim=1
3. 实例演示
任务一:
取top1(最大值):
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
print(pred)
values, indices = pred.topk(1, dim=0, largest=True, sorted=True)
print(indices)
print(values)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=0, keepdim=True)
print(indices_max)
print(indices_max == indices)
输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
tensor([[1, 1, 1, 1, 1]])
tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])
tensor([[1, 1, 1, 1, 1]])
tensor([[True, True, True, True, True]])
任务二:
按行取出topk,将小于topk的置为inf:
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
print(pred)
top_k = 2 # 按行求出每一行的最大的前两个值
filter_value=-float('Inf')
indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]
print(indices_to_remove)
pred[indices_to_remove] = filter_value # 对于topk之外的其他元素的logits值设为负无穷
print(pred)
输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
tensor([[4],
[4],
[4],
[3]])
tensor([[0.4053],
[1.8823],
[1.7255],
[0.3849]])
tensor([[ True, False, True, True, False],
[ True, False, True, True, False],
[ True, True, False, True, False],
[ True, False, True, False, True]])
tensor([[ -inf, -0.3873, -inf, -inf, 0.4053],
[ -inf, 1.4164, -inf, -inf, 1.8823],
[ -inf, -inf, 1.2590, -inf, 1.7255],
[ -inf, 0.3041, -inf, 0.3849, -inf]])
任务三:
import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
[3,4,5,1,1,1,1,1,1,1,1],
[7,8,9,1,1,1,1,1,1,1,1],
[1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)
# tensor2=torch.tensor([[3,2,1],
# [6,5,4],
# [1,4,7],
# [9,8,7]],dtype=torch.float32)
#
print('tensor1:')
print(tensor1)
print('直接输出topk,会得到两个东西,我们需要的是第二个indices')
print(torch.topk(tensor1, k=3, dim=1, largest=True))
print('topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])
print('topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''
tensor1:
tensor([[10., 1., 2., 1., 1., 1., 1., 1., 1., 1., 10.],
[ 3., 4., 5., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 7., 8., 9., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 1., 4., 7., 1., 1., 1., 1., 1., 1., 1., 1.]])
直接输出topk,会得到两个东西,我们需要的是第二个indices
torch.return_types.topk(
values=tensor([[10., 10., 2.],
[ 5., 4., 3.],
[ 9., 8., 7.],
[ 7., 4., 1.]]),
indices=tensor([[ 0, 10, 2],
[ 2, 1, 0],
[ 2, 1, 0],
[ 2, 1, 0]]))
topk[0]如下
tensor([[10., 10., 2.],
[ 5., 4., 3.],
[ 9., 8., 7.],
[ 7., 4., 1.]])
topk[1]如下
tensor([[ 0, 10, 2],
[ 2, 1, 0],
[ 2, 1, 0],
[ 2, 1, 0]])
'''
加油!
感谢!
努力!
以上是关于PyTorchtorch.topk() 函数详解的主要内容,如果未能解决你的问题,请参考以下文章