如何在 PyTorch 中将句子长度批量转换为掩码?

Posted

技术标签:

【中文标题】如何在 PyTorch 中将句子长度批量转换为掩码?【英文标题】:How to batch convert sentence lengths to masks in PyTorch? 【发布时间】:2019-04-23 12:07:48 【问题描述】:

例如,来自

lens = [3, 5, 4]

我们想要得到

mask = [[1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0]]

两者都是torch.LongTensors。

【问题讨论】:

【参考方案1】:

我发现的一种方法是:

torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)

如果有更好的方法请分享!

【讨论】:

【参考方案2】:

只是为了对@ypc的回答提供一点解释(由于缺乏声誉,无法发表评论):

torch.arange(max_len)[None, :] &lt; lens[:, None]

总之,答案使用broadcasting 机制隐式地expand 张量,正如在接受的答案中所做的那样。一步一步:

    torch.arange(max_len) 给你[0, 1, 2, 3, 4];

    添加 [None, :] 会将第 0 维附加到张量,使其形状为 (1, 5),从而得到 [[0, 1, 2, 3, 4]]

    类似地,lens[:, None] 将第一个维度附加到张量lens,使其形状为(3, 1),即[[3], [5], [4]]

    按照broadcasting的规则,通过比较(或做任何类似+、-、*、/等的操作)张量(1, 5)(3, 1),得到的张量形状为(3, 5),结果值为result[i, j] = (j &lt; lens[i])

【讨论】:

【参考方案3】:

torch.arange(max_len)[None, :] &lt; lens[:, None]

【讨论】:

尝试描述你的答案。它将帮助其他人轻松捕获您的代码。请参阅***.com/help/how-to-answer

以上是关于如何在 PyTorch 中将句子长度批量转换为掩码?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Pytorch 中将一维 IntTensor 转换为 int

PyTorch 数据加载器显示字符串数据集的奇怪行为

补充材料

在Java中将句子字符串转换为单词的字符串数组

在python中将cidr转换为子网掩码

在 pytorch 中通过具有线性输出层的 RNN 发送的填充批次的掩码和计算损失