如何根据 Pytorch 中列表的值定义掩码函数

Posted

技术标签:

【中文标题】如何根据 Pytorch 中列表的值定义掩码函数【英文标题】:How can I define a mask function based on the values of a list in Pytorch 【发布时间】:2022-01-10 23:54:08 【问题描述】:

我想根据张量的值屏蔽张量。在下面的函数中,如果我传递一个范围(第二部分)它可以工作,但我想要一个包含各种值的列表prompt_ids(3、8、9、30)。但它不起作用并引发错误。

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

功能:

   def get_prompt_token_fn(self):
        if self.prompt_ids:
            return lambda x: x in self.prompt_ids
        else:
            return lambda x: (x>=self.id_offset)&(x<self.id_offset+self.length)

有什么问题,我该如何解决?

【问题讨论】:

【参考方案1】:

pytorch 1.10 中有一个isin 函数,它根据第一个数组的元素在第二个数组中的条件返回一个布尔数组。对于低于它的版本,可以如下实现:

def isin(ar1, ar2):
    return (ar1[..., None] == ar2).any(-1)

【讨论】:

以上是关于如何根据 Pytorch 中列表的值定义掩码函数的主要内容,如果未能解决你的问题,请参考以下文章

pytorch利用类似掩码的功能把一些值置为0

SqlServer中的数据根据该表中某字段的值的结果决定是不是显示

如何根据下拉列表选择的值更改表中的数据

pytorch torch类

PyTorch 和Albumentations 在图像分割的应用

如何在 PyTorch 中将句子长度批量转换为掩码?