torch.argmax和argmin返回值

Posted asthnont

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch.argmax和argmin返回值相关的知识,希望对你有一定的参考价值。

  在进行深度学习张量计算时,经常要获取张量在某个维度的最大值和最小值,以及这些值的位置。如果只需要知道位置,则torch.argmax和torch.argmin函数便可以实现。

Torch.argmax(input, dim=None, keepdim=False):返回指定维度最大值的序号。

  有时候返回的值比较难理解,所以这里直接放example以帮助理解:

 1 import torch
 2 
 3 t = torch.tensor([[1,2],[3,4],[2,8]])
 4 
 5 print(torch.argmax(t,0))
 6 
 7 
 8 g = torch.tensor([[[1,2,3],[2,3,4],[5,6,7]], [[3,4,5],[7,6,5],[5,4,3]], [[8,9,0],        
 9                             [2,8,4],[7,5,3]]])
10 print(g)
11 print(torch.argmax(g,0))

先从简单的2维张量来看,t 是一个2维张量,大小为(3,2)。t 为 技术图片,此时我们使dim=0,意思使求第0维的(即(3,2)中的3行)中的最大值的序号,所以固定行,直接看列,第一列中3最大,故得到值1,第2列中8最大,故得到值2。最终的结果为  tensor([1,2])

 


再来看一个3维张量g , tensor([[[1, 2, 3],

              [2, 3, 4],
              [5, 6, 7]],

              [[3, 4, 5],
              [7, 6, 5],
              [5, 4, 3]],

              [[8, 9, 0],
              [2, 8, 4],
              [7, 5, 3]]]),其大小为(3,3,3) 其中我们希望在dim=0的维度中求最大值的序号,则固定第一个维度,第一个维度为channel,则每个channel中对应位置进行比较。

比如每个channel中的(0,0)比较,1<3<8,所以得到的值为2;(0,1)比较,2<4<9,依然得到2,....以此类推。最终得到结果tensor([[2, 2, 1],[1, 2, 1],[2, 0, 0]])。

 

以上是关于torch.argmax和argmin返回值的主要内容,如果未能解决你的问题,请参考以下文章

torch.argmax与torch.max详解

Pytorch:解释 torch.argmax

当想要找到具有最高 `start` 分数的标记时,torch.argmax() 中的 TypeError

关于 decoder_outputs[:,t,:] = decoder_output_t torch.topk, torch.max(),torch.argmax()的演示

小白学习之pytorch框架-softmax回归(torch.gather()torch.argmax())

pandas使用argmin函数返回给定series对象中最小值(minminimum)的行索引实战