如何在 PyTorch 中将句子长度批量转换为掩码?
Posted
技术标签:
【中文标题】如何在 PyTorch 中将句子长度批量转换为掩码?【英文标题】:How to batch convert sentence lengths to masks in PyTorch? 【发布时间】:2019-04-23 12:07:48 【问题描述】:例如,来自
lens = [3, 5, 4]
我们想要得到
mask = [[1, 1, 1, 0, 0],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 0]]
两者都是torch.LongTensor
s。
【问题讨论】:
【参考方案1】:我发现的一种方法是:
torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)
如果有更好的方法请分享!
【讨论】:
【参考方案2】:只是为了对@ypc的回答提供一点解释(由于缺乏声誉,无法发表评论):
torch.arange(max_len)[None, :] < lens[:, None]
总之,答案使用broadcasting
机制隐式地expand
张量,正如在接受的答案中所做的那样。一步一步:
torch.arange(max_len) 给你[0, 1, 2, 3, 4]
;
添加 [None, :]
会将第 0 维附加到张量,使其形状为 (1, 5)
,从而得到 [[0, 1, 2, 3, 4]]
;
类似地,lens[:, None]
将第一个维度附加到张量lens
,使其形状为(3, 1)
,即[[3], [5], [4]]
;
按照broadcasting
的规则,通过比较(或做任何类似+、-、*、/等的操作)张量(1, 5)
和(3, 1)
,得到的张量形状为(3, 5)
,结果值为result[i, j] = (j < lens[i])
。
【讨论】:
【参考方案3】:torch.arange(max_len)[None, :] < lens[:, None]
【讨论】:
尝试描述你的答案。它将帮助其他人轻松捕获您的代码。请参阅***.com/help/how-to-answer以上是关于如何在 PyTorch 中将句子长度批量转换为掩码?的主要内容,如果未能解决你的问题,请参考以下文章