在 pytorch 的嵌入层中“究竟”发生了啥?
Posted
技术标签:
【中文标题】在 pytorch 的嵌入层中“究竟”发生了啥?【英文标题】:What "exactly" happens inside embedding layer in pytorch?在 pytorch 的嵌入层中“究竟”发生了什么? 【发布时间】:2020-03-02 06:08:49 【问题描述】:通过多次搜索和 pytorch 文档本身,我可以发现在嵌入层内部有一个查找表,其中存储了嵌入向量。我无法理解的:
-
在这一层的训练过程中究竟发生了什么?
什么是权重以及如何计算这些权重的梯度?
我的直觉是至少应该有一个带有一些参数的函数来生成查找表的键。如果是,那这个函数是什么?
对此的任何帮助将不胜感激。谢谢。
【问题讨论】:
【参考方案1】:这是一个非常好的问题! PyTorch 的嵌入层(Tensorflow 也是如此)用作查找表,仅用于检索每个输入的嵌入,即索引。考虑以下情况,您有一个句子,其中每个单词都被标记化。因此,句子中的每个单词都用一个唯一的整数(索引)表示。如果索引(单词)列表是[1, 5, 9]
,并且您想使用50
维向量(嵌入)对每个单词进行编码,您可以执行以下操作:
# The list of tokens
tokens = torch.tensor([0,5,9], dtype=torch.long)
# Define an embedding layer, where you know upfront that in total you
# have 10 distinct words, and you want each word to be encoded with
# a 50 dimensional vector
embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=50)
# Obtain the embeddings for each of the words in the sentence
embedded_words = embedding(tokens)
现在,回答您的问题:
在前向传递期间,您句子中每个标记的值将以与 Numpy 的索引工作类似的方式获得。因为在后端,这是一个可微分操作,在反向传递(训练)期间,Pytorch 将计算每个嵌入的梯度并相应地重新调整它们。
权重是嵌入本身。词嵌入矩阵其实就是一个权重矩阵,会在训练的时候学习到。
本身没有实际功能。正如我们上面定义的,句子已经被标记化了(每个单词都用一个唯一的整数表示),我们可以获取句子中每个标记的嵌入。
最后,正如我多次提到的索引示例,让我们尝试一下。
# Let us assume that we have a pre-trained embedding matrix
pretrained_embeddings = torch.rand(10, 50)
# We can initialize our embedding module from the embedding matrix
embedding = torch.nn.Embedding.from_pretrained(pretrained_embeddings)
# Some tokens
tokens = torch.tensor([1,5,9], dtype=torch.long)
# Token embeddings from the lookup table
lookup_embeddings = embedding(tokens)
# Token embeddings obtained with indexing
indexing_embeddings = pretrained_embeddings[tokens]
# Voila! They are the same
np.testing.assert_array_equal(lookup_embeddings.numpy(), indexing_embeddings.numpy())
【讨论】:
所以这和one-hot encoding后跟一个线性层完全一样? 没错。我打算这些天有空的时候写一篇博文,我会用链接更新答案。 在您的描述中,您说的是In case the list of indices (words) is [1, 5, 9]
,但您的代码是tokens = torch.tensor([0,5,9],
。为什么从[1,5,9]
更改为[0,5,9]
?
因为当你不仔细检查你写的东西时,你会打错字:) 现在改了:)【参考方案2】:
nn.Embedding
层可以用作查找表。这意味着,如果您有一个包含 n
元素的字典,则可以在创建嵌入时通过 id 调用每个元素。
在这种情况下,字典的大小为 num_embeddings
,embedding_dim
为 1。
在这种情况下,您无需学习任何东西。您可能会说,您只是对字典的元素进行了索引,或者对它们进行了编码。所以在这种情况下不需要前向传递分析。
如果您使用了诸如 Word2vec 之类的词嵌入,您可能已经使用过它。
另一方面,您可以将嵌入层用于分类变量(一般情况下的特征)。在那里,您将嵌入维度 embedding_dim
设置为您可能拥有的类别数。
在这种情况下,您从随机初始化的嵌入层开始,然后学习类别(特征)。
【讨论】:
以上是关于在 pytorch 的嵌入层中“究竟”发生了啥?的主要内容,如果未能解决你的问题,请参考以下文章
当您从 Java servlet 中转发 html 页面时,究竟发生了啥? [复制]