使用 pyTorch 张量沿一个特定维度和 3 维张量进行索引

Posted

技术标签:

【中文标题】使用 pyTorch 张量沿一个特定维度和 3 维张量进行索引【英文标题】:Indexing using pyTorch tensors along one specific dimension with 3 dimensional tensor 【发布时间】:2021-06-10 17:47:46 【问题描述】:

我有 2 个张量:

带有形状(批次、序列、词汇)的A 和 B 的形状(批次、序列)。

A = torch.tensor([[[ 1.,  2.,  3.],
     [ 5.,  6.,  7.]],

    [[ 9., 10., 11.],
     [13., 14., 15.]]])

B = torch.tensor([[0, 2],
    [1, 0]])

我想得到以下内容:

C = torch.zeros_like(B)
for i in range(B.shape[0]):
   for j in range(B.shape[1]):
      C[i,j] = A[i,j,B[i,j]]

但是以矢量化的方式。我尝试了 torch.gather 和其他东西,但我无法让它工作。 谁能帮帮我?

【问题讨论】:

【参考方案1】:
>>> import torch
>>> A = torch.tensor([[[ 1.,  2.,  3.],
...      [ 5.,  6.,  7.]],
... 
...     [[ 9., 10., 11.],
...      [13., 14., 15.]]])
>>> B = torch.tensor([[0, 2],
...     [1, 0]])
>>> A.shape
torch.Size([2, 2, 3])
>>> B.shape
torch.Size([2, 2])
>>> C = torch.zeros_like(B)
>>> for i in range(B.shape[0]):
...    for j in range(B.shape[1]):
...       C[i,j] = A[i,j,B[i,j]]
... 
>>> C
tensor([[ 1,  7],
        [10, 13]])
>>> torch.gather(A, -1, B.unsqueeze(-1))
tensor([[[ 1.],
         [ 7.]],

        [[10.],
         [13.]]])
>>> torch.gather(A, -1, B.unsqueeze(-1)).shape
torch.Size([2, 2, 1])
>>> torch.gather(A, -1, B.unsqueeze(-1)).squeeze(-1)
tensor([[ 1.,  7.],
        [10., 13.]])

您好,您可以使用torch.gather(A, -1, B.unsqueeze(-1)).squeeze(-1)AB.unsqueeze(-1) 之间的第一个 -1 表示您要沿其选取元素的维度。

B.unsqueeze(-1) 中的第二个 -1 是向 B 添加一个暗度,以使两个张量具有相同的暗度,否则您将得到 RuntimeError: Index tensor must have the same number of dimensions as input tensor

最后一个-1 是将结果从torch.Size([2, 2, 1]) 重塑为torch.Size([2, 2])

【讨论】:

非常感谢。这非常有帮助:)

以上是关于使用 pyTorch 张量沿一个特定维度和 3 维张量进行索引的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch Tensor 维度操作的形象理解 Tensor.unsqueeze() Tensor.squeeze()

pytorch-torch2:张量计算和连接

PyTorch 中的连接张量

pytorch torch类

pytorch中gather函数的理解。

Pytorch 重塑张量维度