使用 nn.CrossEntropyLoss() 训练的网络的测试和置信度得分

Posted

技术标签:

【中文标题】使用 nn.CrossEntropyLoss() 训练的网络的测试和置信度得分【英文标题】:Testing and Confidence score of Network trained with nn.CrossEntropyLoss() 【发布时间】:2020-04-28 10:34:16 【问题描述】:

我已经训练了一个具有以下结构的网络:

Intent_LSTM(
(attention): Attention()
(embedding): Embedding(34601, 400)
(lstm): LSTM(400, 512, num_layers=2, batch_first=True, dropout=0.5)
(dropout): Dropout(p=0.5, inplace=False)
(fc): Linear(in_features=512, out_features=3, bias=True)
)

现在我想测试这个训练有素的网络并获得分类的置信度分数。 这是我当前的测试功能实现:

output = model_current(inputs)
pred = torch.round(output.squeeze())
pred = pred.argmax(dim=1, keepdim=True)

现在我的问题如下。

    这里的 pred 只是来自我的网络的全连接层的输出,没有 softmax(根据损失函数的要求)。这是(pred = pred.argmax(dim=1, keepdim=True)) 获得预测的正确方法吗?还是应该将网络的输出传递给 softmax 层,然后进行 argmax?

    如何获得置信度分数?我应该将网络的输出传递到 softmax 层并选择 argmax 作为类的置信度吗?

【问题讨论】:

【参考方案1】:
    在进行 softmax 之前或之后选择 argmax 并不重要。因为最大化softmax的任何东西也会最大化logits(pre-softmax)值。所以你应该得到相似的值。 Softmax 将为您提供每个类别的分数或概率。因此,做softmax之后的值可以作为置信度分数。

【讨论】:

以上是关于使用 nn.CrossEntropyLoss() 训练的网络的测试和置信度得分的主要内容,如果未能解决你的问题,请参考以下文章

pytorch 自定义损失函数 nn.CrossEntropyLoss

pytorch nn.CrossEntropyLoss() 中的交叉熵损失

Pytorch常用损失函数nn.BCEloss();nn.BCEWithLogitsLoss();nn.CrossEntropyLoss();nn.L1Loss(); nn.MSELoss();(代码

损失函数

小曾带你深入浅出机器学习(小白入门必备,近3万字带你了解机器学习)

训练 LSTM 时跑出 Ram