搞懂PyTorchtorch.argmax() 函数详解

Posted ZSYL

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了搞懂PyTorchtorch.argmax() 函数详解相关的知识,希望对你有一定的参考价值。

torch.argmax 函数详解

1. 函数介绍

torch.argmax(input, dim=None, keepdim=False)

  • 返回指定维度最大值的序号
  • dim给定的定义是:the demention to reduce.也就是把dim这个维度的,变成这个维度的最大值的index。
  1. dim的不同值表示不同维度。特别的在dim=0表示二维中的列,dim=1在二维矩阵中表示行。广泛的来说,我们不管一个矩阵是几维的,比如一个矩阵维度如下:(d0,d1,…,dn−1) ,那么dim=0就表示对应到d0 也就是第一个维度,dim=1表示对应到也就是第二个维度,依次类推。
  2. 知道dim的值是什么意思还不行,还要知道函数中这个dim给出来会发生什么?

举例说明:

例子1torch.argmax()函数中dim表示该维度会消失。

这个消失是什么意思?

官方英文解释是:dim (int) – the dimension to reduce.

我们知道argmax就是得到最大值的序号索引,对于一个维度为(d0,d1) 的矩阵来说,我们想要求每一行中最大数的在该行中的列号,最后我们得到的就是一个维度为(d0,1) 的一维矩阵。这时候,列这一维度就要消失了。

因此,我们想要求每一行最大的列标号,我们就要指定dim=1,表示我们不要列了,保留行的size就可以了。

假如我们想求每一列的最大行标,就可以指定dim=0,表示我们不要行了,求出每一列的最大值的下标,最后得到(1,d1)的一维矩阵。

2. 实例演示

实例1:

import torch
a = torch.tensor(
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ])
b = torch.argmax(a, dim=0)
print(b)
print(a.shape)

输出结果:

tensor([1, 2, 0, 1])
torch.Size([3, 4])

dim=0的维度为3,即在那3组数据中作比较,求得是每一列中的最大行标,因此为[1,2,0,4]。

实例2:

import torch
a = torch.tensor([
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ],
 
              [
                  [-1, 7, -5, 2],
                  [9, 6, 2, 8],
                  [3, 7, 9, 1]
              ]])
b = torch.argmax(a, dim=0)
print(b)
print(a.shape)
 
"""
tensor([[0, 1, 0, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])
torch.Size([2, 3, 4])"""
 
# dim=0,即将第一个维度消除,也就是将两个[3*4]矩阵只保留一个,因此要在两组中作比较,即将上下两个[3*4]的矩阵分别在对应的位置上比较大小
 
b = torch.argmax(a, dim=1)
"""
tensor([[1, 2, 0, 1],
        [1, 2, 2, 1]])
torch.Size([2, 3, 4])
"""
# dim=1,即将第二个维度消除,这么理解:矩阵维度变为[2*4];
"""
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1];
纵向压缩成一维,因此变为[1,2,0,1];同理得到[1,2,2,1];
"""
b = torch.argmax(a,dim=2)
"""
tensor([[2, 0, 1],
        [1, 0, 2]])
"""
# dim=2,即将第三个维度消除,这么理解:矩阵维度变为[2*3]
"""
   [1, 5, 5, 2],
   [9, -6, 2, 8],
   [-3, 7, -9, 1];
横向压缩成一维
[2,0,1],同理得到下面的"""

参考Link


加油!

感谢!

努力!

以上是关于搞懂PyTorchtorch.argmax() 函数详解的主要内容,如果未能解决你的问题,请参考以下文章

泛函编程(15)-泛函状态-随意数产生器

泛函编程(16)-泛函状态-Functional State

java 类中的建构函式和解构函式名都是啥,java需要手动释放资源吗?

泛函编程(17)-泛函状态-State In Action

一文彻底搞懂 CMS GC 参数配置

学JS的心路历程-函式箭头函式