torch.argmax和argmin返回值
Posted asthnont
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch.argmax和argmin返回值相关的知识,希望对你有一定的参考价值。
在进行深度学习张量计算时,经常要获取张量在某个维度的最大值和最小值,以及这些值的位置。如果只需要知道位置,则torch.argmax和torch.argmin函数便可以实现。
Torch.argmax(input, dim=None, keepdim=False
):返回指定维度最大值的序号。
有时候返回的值比较难理解,所以这里直接放example以帮助理解:
1 import torch 2 3 t = torch.tensor([[1,2],[3,4],[2,8]]) 4 5 print(torch.argmax(t,0)) 6 7 8 g = torch.tensor([[[1,2,3],[2,3,4],[5,6,7]], [[3,4,5],[7,6,5],[5,4,3]], [[8,9,0], 9 [2,8,4],[7,5,3]]]) 10 print(g) 11 print(torch.argmax(g,0))
先从简单的2维张量来看,t 是一个2维张量,大小为(3,2)。t 为 ,此时我们使dim=0,意思使求第0维的(即(3,2)中的3行)中的最大值的序号,所以固定行,直接看列,第一列中3最大,故得到值1,第2列中8最大,故得到值2。最终的结果为 tensor([1,2])
再来看一个3维张量g , tensor([[[1, 2, 3],
[2, 3, 4],
[5, 6, 7]],
[[3, 4, 5],
[7, 6, 5],
[5, 4, 3]],
[[8, 9, 0],
[2, 8, 4],
[7, 5, 3]]]),其大小为(3,3,3) 其中我们希望在dim=0的维度中求最大值的序号,则固定第一个维度,第一个维度为channel,则每个channel中对应位置进行比较。
比如每个channel中的(0,0)比较,1<3<8,所以得到的值为2;(0,1)比较,2<4<9,依然得到2,....以此类推。最终得到结果tensor([[2, 2, 1],[1, 2, 1],[2, 0, 0]])。
以上是关于torch.argmax和argmin返回值的主要内容,如果未能解决你的问题,请参考以下文章
当想要找到具有最高 `start` 分数的标记时,torch.argmax() 中的 TypeError
关于 decoder_outputs[:,t,:] = decoder_output_t torch.topk, torch.max(),torch.argmax()的演示