如何反转 PyTorch 嵌入?

Posted

技术标签:

【中文标题】如何反转 PyTorch 嵌入?【英文标题】:How to invert a PyTorch Embedding? 【发布时间】:2021-02-07 22:08:47 【问题描述】:

我在 PyTorch 中有一个多任务编码器/解码器模型,在输入端有一个(可训练的)torch.nn.Embedding 嵌入层。

在一项特定任务中,我想对模型进行自我监督预训练(以重建屏蔽输入数据)并将其用于推理(以填补数据空白)。

我想对于训练时间,我可以将损失测量为输入嵌入和输出嵌入之间的距离...但是对于推理,我如何反转 Embedding 以重建输出对应的正确类别/标记?我看不到例如Embedding 类的“最近”函数...

【问题讨论】:

对于invert an Embedding to reconstruct the proper category/token the output corresponds to,您通常会在输出嵌入上添加一个分类器(例如,使用 softmax)来查找预测的标记或类。 【参考方案1】:

你可以很容易地做到这一点:

import torch

embeddings = torch.nn.Embedding(1000, 100)
my_sample = torch.randn(1, 100)
distance = torch.norm(embeddings.weight.data - my_sample, dim=1)
nearest = torch.argmin(distance)

假设您有 1000 维数为 100 的标记,这将返回基于欧几里得距离的最近嵌入。您也可以以类似的方式使用其他指标。

【讨论】:

不应该embeddings(nearest) 等于(或类似)my_sample?我错过了什么吗? 是的,经过培训。这个 sn-p 专注于获取最接近当前样本的嵌入,并假设嵌入在某种程度上接近。

以上是关于如何反转 PyTorch 嵌入?的主要内容,如果未能解决你的问题,请参考以下文章

在 pytorch 的嵌入层中“究竟”发生了啥?

如何正确地为 PyTorch 中的嵌入、LSTM 和线性层提供输入?

pytorch之词嵌入

嵌入pytorch

pytorch中,嵌入层torch.nn.embedding的计算方式

在 Pytorch 中嵌入 3D 数据