PyTorch:除前k之外的向量的所有元素都归零?

Posted

技术标签:

【中文标题】PyTorch:除前k之外的向量的所有元素都归零?【英文标题】:PyTorch: Zero all elements of vector except top k? 【发布时间】:2021-04-14 22:36:07 【问题描述】:

我正在尝试创建一个新的激活层,我们称之为 topk,它的工作方式如下。它将一个大小为 n 的向量 x 作为输入(将前一层输出乘以权重矩阵并添加偏差的结果)和一个正整数 k,并将输出一个大小为 n 的向量 topk(x),其元素是:

              x_i (if x_i is one of the top k elements of x) 
topk(x)_i = 
              0 (otherwise)

在计算topk(x)的梯度时,x的前k个元素的梯度应该是1,其他的都是0。

我应该如何实现这个?任何帮助将不胜感激。

【问题讨论】:

【参考方案1】:

您可以为此使用torch.topk

k = 2
output = torch.randn(5)
vals, idx = output.topk(k)

topk = torch.zeros_like(output)
topk[idx] = vals
>>> topk
tensor([1.0557, 0.0000, 0.0000, 1.4562, 0.0000])

请注意,虽然 topk()'values' 是可微分的,但 'indices' are not (类似于 argmax 是不可微分的函数)。

【讨论】:

以上是关于PyTorch:除前k之外的向量的所有元素都归零?的主要内容,如果未能解决你的问题,请参考以下文章

使用 JQuery 从列表中删除除前 N 个元素之外的所有元素

选择向量中除一个之外的所有元素

如何从表中删除除前两个和最后一个之外的所有行?

使用 :not(selector) 选择除少数之外的所有元素 [重复]

从 EventHandler 中删除除前两个参数外的所有参数

在除前两列之外的每列上前向填充具有最新非空值的空值