pytorch :使用两次sort函数(排序)找出矩阵每个元素在升序或降序排列中的位置

Posted 慢行厚积

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch :使用两次sort函数(排序)找出矩阵每个元素在升序或降序排列中的位置相关的知识,希望对你有一定的参考价值。

在SSD的代码中经常有见到如下的操作:

        _, idx = flt[:, :, 0].sort(1, descending=True)#大小为[batch size, num_classes*top_k]
        _, rank = idx.sort(1)#再对索引升序排列,得到其索引作为排名rank

其作用是什么呢?举个例子:

import torch
a = torch.randn(3,4)
print(a)
print()

i, idx = a.sort(dim=1, descending=True)
print(i)
print(idx)
print()

j, rank = idx.sort(dim=1)
print(rank)

返回:

tensor([[ 2.3326,  0.0275, -0.0799,  0.4156],
        [-2.2066,  1.7997, -2.2767,  0.4704],
        [-0.6980,  0.2285,  1.0018, -0.8874]])

tensor([[ 2.3326,  0.4156,  0.0275, -0.0799],
        [ 1.7997,  0.4704, -2.2066, -2.2767],
        [ 1.0018,  0.2285, -0.6980, -0.8874]])
tensor([[0, 3, 1, 2],
        [1, 3, 0, 2],
        [2, 1, 0, 3]])

tensor([[0, 2, 3, 1],
        [2, 0, 3, 1],
        [2, 1, 0, 3]])

其实就是可以通过最后的rank看出对应位置的值的排序位置,这里是降序,所以索引0表示最高

以rank的第一行第一列的值0为例,其表示对应的a中第一行第一列的值2.3326是第一行中最大的,因为设置了dim=1;第一行第三列的值3为例,其表示对应的a中第一行第三列的值-0.0799是第一行中最小的

起到了对应位置排名的作用

 

以上是关于pytorch :使用两次sort函数(排序)找出矩阵每个元素在升序或降序排列中的位置的主要内容,如果未能解决你的问题,请参考以下文章

使用sorted内置函数排序数列来找出最大三个数的乘积

pytorch排序耗时较多

冒泡排序python优化版本

DevOps让金融业数字化转型更敏捷 | 分享实录

找出两个数组中相同的元素,不排序直接两次循环取出

Pytorch如何获得两次失落函数的梯度