pytorch - torch.gather 的倒数

Posted

技术标签:

【中文标题】pytorch - torch.gather 的倒数【英文标题】:pytorch - reciprocal of torch.gather 【发布时间】:2021-10-01 05:25:54 【问题描述】:

给定一个输入张量x 和一个索引张量idxs,我想检索x 的所有元素,其索引在idxs 中不存在。也就是取torch.gather函数输出的反面。

torch.gather 为例:

>>> x = torch.arange(30).reshape(3,10)
>>> idxs = torch.tensor([[1,2,3], [4,5,6], [7,8,9]], dtype=torch.long)
>>> torch.gather(x, 1, idxs)
tensor([[ 1,  2,  3],
        [14, 15, 16],
        [27, 28, 29]])

我真正想要实现的是

tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])

什么是有效且高效的实现,可能使用 torch 实用程序?我不想使用任何 for 循环。

我假设idxs 在其最深处只有独特 元素。例如idxs 将是调用torch.topk 的结果。

【问题讨论】:

您想要的输出不一致。如果idxs 在同一行有两个相同的元素会发生什么,例如 [[1,1,3], [4,5,6], [7,8,9]]。在这种情况下会产生什么结果? @Ivan 我假设 idxs 在其最深处只有独特的元素。例如,假设 idxstorch.topk 的输出。 【参考方案1】:

您可能希望构造一个形状为(x.size(0), x.size(1)-idxs.size(1))(此处为(3, 7))的张量。这将对应于idxs 的互补索引,关于x 的形状,

tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])

我建议首先构建一个形状类似于 x 的张量,它会显示我们想要保留的位置和我们想要丢弃的位置,一种掩码。这可以使用torch.scatter 来完成。这实际上将0s 分散在所需位置,即m[i, idxs[i][j]] = 0

>>> m = torch.ones_like(x).scatter(1, idxs, 0)
tensor([[1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])

然后抓取非零(idxs 的补充部分)。选择axis=1上的第二个索引,根据目标张量reshape:

>>> idxs_ = m.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])

现在你知道该怎么做了吧?与您给出的torch.gather 示例相同,但这次使用idxs_

>>> torch.gather(x, 1, idxs_)
tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])

总结:

>>> idxs_ = torch.ones_like(x).scatter(1, idxs, 0) \
        .nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))

>>> torch.gather(x, 1, idxs_)

【讨论】:

您可以使用布尔掩码m = torch.ones_like(x, dtype=torch.bool).scatter(1, idxs, 0) 来简化代码以使用x[m].reshape(-1, x.size(1) - idxs.size(1)) 选择值。 确实,这样更好 - 更自然! - 使用面具的方式。

以上是关于pytorch - torch.gather 的倒数的主要内容,如果未能解决你的问题,请参考以下文章

小白学习之pytorch框架-softmax回归(torch.gather()torch.argmax())

torch.gather()之通俗易懂讲解

pytorch中gather函数的理解。

pytorch 笔记:gather 函数

pytorch-torch2:张量计算和连接

用于音高检测的倒谱分析