如何根据 Pytorch 中列表的值定义掩码函数
Posted
技术标签:
【中文标题】如何根据 Pytorch 中列表的值定义掩码函数【英文标题】:How can I define a mask function based on the values of a list in Pytorch 【发布时间】:2022-01-10 23:54:08 【问题描述】:我想根据张量的值屏蔽张量。在下面的函数中,如果我传递一个范围(第二部分)它可以工作,但我想要一个包含各种值的列表prompt_ids
(3、8、9、30)。但它不起作用并引发错误。
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
功能:
def get_prompt_token_fn(self):
if self.prompt_ids:
return lambda x: x in self.prompt_ids
else:
return lambda x: (x>=self.id_offset)&(x<self.id_offset+self.length)
有什么问题,我该如何解决?
【问题讨论】:
【参考方案1】:在pytorch 1.10
中有一个isin
函数,它根据第一个数组的元素在第二个数组中的条件返回一个布尔数组。对于低于它的版本,可以如下实现:
def isin(ar1, ar2):
return (ar1[..., None] == ar2).any(-1)
【讨论】:
以上是关于如何根据 Pytorch 中列表的值定义掩码函数的主要内容,如果未能解决你的问题,请参考以下文章
SqlServer中的数据根据该表中某字段的值的结果决定是不是显示