在pytorch中计算困惑度

Posted

技术标签:

【中文标题】在pytorch中计算困惑度【英文标题】:calculate perplexity in pytorch 【发布时间】:2020-03-31 04:59:06 【问题描述】:

我刚刚使用 pytorch 训练了一个 LSTM 语言模型。类的主体是这样的:

class LM(nn.Module):
    def __init__(self, n_vocab, 
                       seq_size, 
                       embedding_size, 
                       lstm_size, 
                       pretrained_embed):

        super(LM, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding.from_pretrained(pretrained_embed, freeze = True)
        self.lstm = nn.LSTM(embedding_size,
                            lstm_size,
                            batch_first=True)
        self.fc = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)

        return logits, state

现在我想写一个函数来计算一个句子的好坏,基于训练的语言模型(一些分数,比如困惑等)

我有点困惑,我不知道该如何计算。 类似的样本会很有用。

【问题讨论】:

【参考方案1】:

使用交叉熵损失时,您只需使用指数函数torch.exp() 从损失中计算困惑度。(pytorch cross-entropy also uses the exponential function resp. log_n)

所以这里只是一些虚拟的例子:

import torch
import torch.nn.functional as F
num_classes = 10
batch_size  = 1

# your model outputs / logits
output      = torch.rand(batch_size, num_classes) 

# your targets
target      = torch.randint(num_classes, (batch_size,))

# getting loss using cross entropy
loss        = F.cross_entropy(output, target)

# calculating perplexity
perplexity  = torch.exp(loss)
print('Loss:', loss, 'PP:', perplexity)  

在我的例子中,输出是:

Loss: tensor(2.7935) PP: tensor(16.3376)

如果您想获得每个单词丢失所需的每个单词的困惑度,您只需要注意这一点。

这是一个语言模型的简洁示例,它可能会从输出中计算出困惑度:

https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model/main.py#L30-L50

【讨论】:

感谢您的回答。我的问题有点不同,因为我只想给出一个句子(一个标记列表)作为输入,并得到一个分数作为输出。在这种情况下,我应该将句子和移位句子作为示例代码中的 outputtarget 给出吗? @P.Alipoor 是的,当查看索引为 i 的令牌时,目标应该是令牌 i+1

以上是关于在pytorch中计算困惑度的主要内容,如果未能解决你的问题,请参考以下文章

当目标不是单热时,如何计算 Pytorch 中 2 个张量之间的正确交叉熵?

PyTorch 新手,使用 Data Loader 加载数据后无法进行预测

自然语言推断(NLI)文本相似度相关开源项目推荐(Pytorch 实现)

使用cnn提取特征,图像相似度对比。pytorch 推理的时候报内存不足的问题

计算机视觉PyTorch实现

pytorch 计算成对差异:NumPy 与 PyTorch 和不同 PyTorch 版本的结果不正确