Pytorch:解释 torch.argmax
Posted
技术标签:
【中文标题】Pytorch:解释 torch.argmax【英文标题】:Pytorch: explain torch.argmax 【发布时间】:2020-11-08 11:53:38 【问题描述】:您好,我有以下代码:
import torch
x = torch.zeros(1,8,4,576) # create a 4 dimensional tensor
x[0,4,2,333] = 1.0 # put on 1 on a random spot
# I want to find the index of the highest value (0,4,2,333)
print(x.argmax()) # this should return the index
返回
张量(10701)
这个 10701 有什么意义?
如何获得实际索引 0,4,2,333?
【问题讨论】:
【参考方案1】:4维数组中的数据线性存储在内存中,argmax()
返回这个平面表示的对应索引。
Numpy 具有解开索引的功能(从平面数组索引转换为对应的多维索引)。
import numpy as np
np.unravel_index(10701, (1,8,4,576))
【讨论】:
以上是关于Pytorch:解释 torch.argmax的主要内容,如果未能解决你的问题,请参考以下文章
pytorch 1.9.0 backward函数解释以及报错(RuntimeError: grad can be implicitly created only for scalar outputs)
PyTorch里面的torch.nn.Parameter()