Pytorch Tensor 如何获取特定值的索引
Posted
技术标签:
【中文标题】Pytorch Tensor 如何获取特定值的索引【英文标题】:How Pytorch Tensor get the index of specific value 【发布时间】:2018-05-31 11:21:53 【问题描述】:在python列表中,我们可以使用list.index(somevalue)
。 pytorch 如何做到这一点?
例如:
a=[1,2,3]
print(a.index(2))
然后,1
将被输出。 pytorch 张量如何在不将其转换为 python 列表的情况下做到这一点?
【问题讨论】:
这个问题随后在这里得到了回答:***.com/a/51704350/799988 【参考方案1】:我认为没有从 list.index()
直接转换为 pytorch 函数。但是,您可以使用tensor==number
和nonzero()
函数获得类似的结果。例如:
t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero(as_tuple=True)[0])
这段代码返回
1
[torch.LongTensor 大小为 1x1]
【讨论】:
如果没有匹配会发生什么?你会怎么处理呢? @CharlieParker 如果没有匹配则返回空张量tensor([], dtype=torch.int64)
我们如何扩展它以获得批量索引?在这种情况下,我想要一次索引值torch.Tensor([1, 2, 3])
,而不仅仅是2
。有没有没有for循环的方法?【参考方案2】:
对于多维张量,您可以这样做:
(tensor == target_value).nonzero(as_tuple=True)
生成的张量的形状为number_of_matches x tensor_dimension
。例如,假设 tensor
是一个 3 x 4
张量(这意味着维度是 2),结果将是一个二维张量,其中包含行中匹配项的索引。
tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
(tensor == 2).nonzero(as_tuple=False)
>>> tensor([[0, 1],
[0, 2],
[1, 2]])
【讨论】:
最完整的答案!一般不仅仅是平坦张量。 如果没有匹配会发生什么?你会怎么处理呢? 在这种情况下,你会得到一个空的张量。张量形状仍将遵循输入张量的尺寸,因此在上面的示例中,搜索8
将导致形状为0 x 2
的(空)张量。【参考方案3】:
根据其他人的回答:
t = torch.Tensor([1, 2, 3])
print((t==1).nonzero().item())
【讨论】:
如果张量只包含一个预期数字的出现,这没关系。那是因为.item()
方法只能在单元素张量上调用,否则会报错。【参考方案4】:
可以通过如下转换为numpy来完成
import torch
x = torch.range(1,4)
print(x)
===> tensor([ 1., 2., 3., 4.])
nx = x.numpy()
np.where(nx == 3)[0][0]
===> 2
【讨论】:
请注意,如果您转换为 numpy,则会丢失梯度图。 如果没有匹配会发生什么?你会怎么处理呢?【参考方案5】:已经给出的答案很好,但是当我尝试没有匹配时它们无法处理。为此,请参阅:
def index(tensor: Tensor, value, ith_match:int =0) -> Tensor:
"""
Returns generalized index (i.e. location/coordinate) of the first occurence of value
in Tensor. For flat tensors (i.e. arrays/lists) it returns the indices of the occurrences
of the value you are looking for. Otherwise, it returns the "index" as a coordinate.
If there are multiple occurences then you need to choose which one you want with ith_index.
e.g. ith_index=0 gives first occurence.
Reference: https://***.com/a/67175757/1601580
:return:
"""
# bool tensor of where value occurred
places_where_value_occurs = (tensor == value)
# get matches as a "coordinate list" where occurence happened
matches = (tensor == value).nonzero() # [number_of_matches, tensor_dimension]
if matches.size(0) == 0: # no matches
return -1
else:
# get index/coordinate of the occurence you want (e.g. 1st occurence ith_match=0)
index = matches[ith_match]
return index
感谢这个出色的答案:https://***.com/a/67175757/1601580
【讨论】:
【参考方案6】:用于查找一维张量/数组中元素的索引 示例
mat=torch.tensor([1,8,5,3])
找到5的索引
five=5
numb_of_col=4
for o in range(numb_of_col):
if mat[o]==five:
print(torch.tensor([o]))
要查找 2d/3d 张量的元素索引,将其转换为 1d #ie example.view(元素个数)
例子
mat=torch.tensor([[1,2],[4,3])
#to find index of 2
five = 2
mat=mat.view(4)
numb_of_col = 4
for o in range(numb_of_col):
if mat[o] == five:
print(torch.tensor([o]))
【讨论】:
这是一个老问题,OP 很可能不会寻找快速答案。如果您想发布答案,请花点时间解释您的代码的作用以及它如何添加到已经存在的答案中。现在的代码块对社区几乎没有价值【参考方案7】:对于浮点张量,我用这个来获取张量中元素的索引。
print((torch.abs((torch.max(your_tensor).item()-your_tensor))<0.0001).nonzero())
这里我想获取max_value在float tensor中的索引,你也可以把你的value这样放到tensor中任意元素的索引。
print((torch.abs((YOUR_VALUE-your_tensor))<0.0001).nonzero())
【讨论】:
【参考方案8】: import torch
x_data = variable(torch.Tensor([[1.0], [2.0], [3.0]]))
print(x_data.data[0])
>>tensor([1.])
【讨论】:
以上是关于Pytorch Tensor 如何获取特定值的索引的主要内容,如果未能解决你的问题,请参考以下文章
使用 pyTorch 张量沿一个特定维度和 3 维张量进行索引
Pytorch深度学习实战3-2:什么是张量?Tensor的创建与索引