padding_idx 在 nn.embeddings() 中做了啥

Posted

技术标签:

【中文标题】padding_idx 在 nn.embeddings() 中做了啥【英文标题】:what does padding_idx do in nn.embeddings()padding_idx 在 nn.embeddings() 中做了什么 【发布时间】:2020-07-25 02:20:48 【问题描述】:

我正在学习 pytorch 和 我想知道padding_idx 属性在torch.nn.Embedding(n1, d1, padding_idx=0) 中的作用是什么? 我到处找,找不到我能得到的东西。 你能举个例子来说明这一点吗?

【问题讨论】:

【参考方案1】:

根据docs,每当遇到索引时,padding_idx 都会使用 padding_idx(初始化为零)处的嵌入向量填充输出。

这意味着无论你有一个等于padding_idx 的项目,嵌入层在该索引处的输出将全为零。

这是一个例子: 假设您有 1000 个词的词嵌入,每个词为 50 维,即num_embeddingss=1000embedding_dim=50。然后torch.nn.Embedding 就像查找表一样工作(尽管查找表是可训练的):

emb_layer = torch.nn.Embedding(1000,50)
x = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
y = emb_layer(x)

y 将是一个形状为 2x4x50 的张量。我希望这部分对你来说是清楚的。

现在如果我指定padding_idx=2,即

emb_layer = torch.nn.Embedding(1000,50, padding_idx=2)
x = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
y = emb_layer(x)

那么输出仍将是 2x4x50,但 (1,2) 和 (2,3) 处的 50 维向量将全为零,因为 x[1,2]x[2,3] 的值为 2,等于 padding_idx . 您可以将其视为查找表中的第三个单词(因为查找表将是 0 索引)未用于训练。

【讨论】:

你的意思是 x[0,1] 和 x[1,2] 都为零吗? 我认为他指的是 y[0,1,2] 和 y[1,2,3] 是大小为 50 的零向量。 @Bhashithe 是的。我认为这是一个矩阵并将其读取为 1-indexed(就像人类一样)。我已经编辑了答案,现在将两者都设为 0。 这是否意味着 padding_idx 屏蔽了输入?【参考方案2】:

padding_idx 在documentation 中的描述确实很糟糕。

基本上,它指定在调用期间传递的索引将意味着“零向量”(这在 NLP 中经常使用,以防某些令牌丢失)。默认情况下,没有索引意味着“零向量”,如下例所示:

import torch

embedding = torch.nn.Embedding(10, 3)
input = torch.LongTensor([[0, 1, 0, 5]])
print(embedding(input))

会给你:

tensor([[[ 0.1280, -1.1390, -2.5007],
         [ 0.3617, -0.9280,  1.2894],
         [ 0.1280, -1.1390, -2.5007],
         [-1.3135, -0.0229,  0.2451]]], grad_fn=<EmbeddingBackward>)

如果您在每个input 中指定padding_idx=0,其中的值等于0(因此第零和第二行)将是zero-ed,如下所示(代码:embedding = torch.nn.Embedding(10, 3, padding_idx=0)):

tensor([[[ 0.0000,  0.0000,  0.0000],
         [-0.4448, -0.2076,  1.1575],
         [ 0.0000,  0.0000,  0.0000],
         [ 1.3602, -0.6299, -0.5809]]], grad_fn=<EmbeddingBackward>

如果您要指定padding_idx=5,最后一行将充满零等。

【讨论】:

以上是关于padding_idx 在 nn.embeddings() 中做了啥的主要内容,如果未能解决你的问题,请参考以下文章

分配的变量引用在哪里,在堆栈中还是在堆中?

NOIP 2015 & SDOI 2016 Round1 & CTSC 2016 & SDOI2016 Round2游记

秋的潇洒在啥?在啥在啥?

上传的数据在云端的怎么查看,保存在啥位置?

在 React 应用程序中在哪里转换数据 - 在 Express 中还是在前端使用 React?

存储在 plist 中的数据在模拟器中有效,但在设备中无效