pytorch - torch.gather 的倒数
Posted
技术标签:
【中文标题】pytorch - torch.gather 的倒数【英文标题】:pytorch - reciprocal of torch.gather 【发布时间】:2021-10-01 05:25:54 【问题描述】:给定一个输入张量x
和一个索引张量idxs
,我想检索x
的所有元素,其索引在idxs
中不存在。也就是取torch.gather
函数输出的反面。
以torch.gather
为例:
>>> x = torch.arange(30).reshape(3,10)
>>> idxs = torch.tensor([[1,2,3], [4,5,6], [7,8,9]], dtype=torch.long)
>>> torch.gather(x, 1, idxs)
tensor([[ 1, 2, 3],
[14, 15, 16],
[27, 28, 29]])
我真正想要实现的是
tensor([[ 0, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26]])
什么是有效且高效的实现,可能使用 torch 实用程序?我不想使用任何 for 循环。
我假设idxs
在其最深处只有独特 元素。例如idxs
将是调用torch.topk
的结果。
【问题讨论】:
您想要的输出不一致。如果idxs
在同一行有两个相同的元素会发生什么,例如 [[1,1,3], [4,5,6], [7,8,9]]
。在这种情况下会产生什么结果?
@Ivan 我假设 idxs 在其最深处只有独特的元素。例如,假设 idxs 是 torch.topk
的输出。
【参考方案1】:
您可能希望构造一个形状为(x.size(0), x.size(1)-idxs.size(1))
(此处为(3, 7)
)的张量。这将对应于idxs
的互补索引,关于x
的形状,即:
tensor([[0, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 7, 8, 9],
[0, 1, 2, 3, 4, 5, 6]])
我建议首先构建一个形状类似于 x
的张量,它会显示我们想要保留的位置和我们想要丢弃的位置,一种掩码。这可以使用torch.scatter
来完成。这实际上将0
s 分散在所需位置,即m[i, idxs[i][j]] = 0
:
>>> m = torch.ones_like(x).scatter(1, idxs, 0)
tensor([[1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
然后抓取非零(idxs
的补充部分)。选择axis=1
上的第二个索引,根据目标张量reshape:
>>> idxs_ = m.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
tensor([[0, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 7, 8, 9],
[0, 1, 2, 3, 4, 5, 6]])
现在你知道该怎么做了吧?与您给出的torch.gather
示例相同,但这次使用idxs_
:
>>> torch.gather(x, 1, idxs_)
tensor([[ 0, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26]])
总结:
>>> idxs_ = torch.ones_like(x).scatter(1, idxs, 0) \
.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
>>> torch.gather(x, 1, idxs_)
【讨论】:
您可以使用布尔掩码m = torch.ones_like(x, dtype=torch.bool).scatter(1, idxs, 0)
来简化代码以使用x[m].reshape(-1, x.size(1) - idxs.size(1))
选择值。
确实,这样更好 - 更自然! - 使用面具的方式。以上是关于pytorch - torch.gather 的倒数的主要内容,如果未能解决你的问题,请参考以下文章