使用 nn.CrossEntropyLoss() 训练的网络的测试和置信度得分
Posted
技术标签:
【中文标题】使用 nn.CrossEntropyLoss() 训练的网络的测试和置信度得分【英文标题】:Testing and Confidence score of Network trained with nn.CrossEntropyLoss() 【发布时间】:2020-04-28 10:34:16 【问题描述】:我已经训练了一个具有以下结构的网络:
Intent_LSTM(
(attention): Attention()
(embedding): Embedding(34601, 400)
(lstm): LSTM(400, 512, num_layers=2, batch_first=True, dropout=0.5)
(dropout): Dropout(p=0.5, inplace=False)
(fc): Linear(in_features=512, out_features=3, bias=True)
)
现在我想测试这个训练有素的网络并获得分类的置信度分数。 这是我当前的测试功能实现:
output = model_current(inputs)
pred = torch.round(output.squeeze())
pred = pred.argmax(dim=1, keepdim=True)
现在我的问题如下。
这里的 pred 只是来自我的网络的全连接层的输出,没有 softmax(根据损失函数的要求)。这是(pred = pred.argmax(dim=1, keepdim=True)) 获得预测的正确方法吗?还是应该将网络的输出传递给 softmax 层,然后进行 argmax?
如何获得置信度分数?我应该将网络的输出传递到 softmax 层并选择 argmax 作为类的置信度吗?
【问题讨论】:
【参考方案1】:-
在进行 softmax 之前或之后选择
argmax
并不重要。因为最大化softmax的任何东西也会最大化logits(pre-softmax)值。所以你应该得到相似的值。
Softmax 将为您提供每个类别的分数或概率。因此,做softmax之后的值可以作为置信度分数。
【讨论】:
以上是关于使用 nn.CrossEntropyLoss() 训练的网络的测试和置信度得分的主要内容,如果未能解决你的问题,请参考以下文章
pytorch 自定义损失函数 nn.CrossEntropyLoss
pytorch nn.CrossEntropyLoss() 中的交叉熵损失
Pytorch常用损失函数nn.BCEloss();nn.BCEWithLogitsLoss();nn.CrossEntropyLoss();nn.L1Loss(); nn.MSELoss();(代码