如何标准化 BERT 分类器的输出

Posted

技术标签:

【中文标题】如何标准化 BERT 分类器的输出【英文标题】:How to normalize output from BERT classifier 【发布时间】:2020-11-16 15:23:54 【问题描述】:

我已经使用 HuggingFace transformers.TFBertForSequenceClassification 分类器训练了一个 BERT 分类器。它工作正常,但是当使用 model.predict() 方法时,它会给出一个元组作为输出,该元组在 [0, 1] 之间没有标准化。例如。我训练模型将新闻文章分类为欺诈和非欺诈类别。然后我将以下4个测试数据输入模型进行预测:

articles = ['He was involved in the insider trading scandal.', 
            'Johnny was a good boy. May his soul rest in peace', 
            'The fraudster stole money using debit card pin', 
            'Sun rises in the east']

输出是:

[[-2.8615277,  2.6811066],
 [ 2.8651822, -2.564444 ],
 [-2.8276567,  2.4451752],
 [ 2.770451 , -2.3713884]]

对我来说,label-0 代表非欺诈,label-1 代表欺诈,所以效果很好。但是我如何从这里准备得分信心呢?在这种情况下,使用 softmax 进行归一化是否有意义?另外,如果我想看看模型有点优柔寡断的那些预测,我该怎么做?在那种情况下,这两个值会非常接近吗?

【问题讨论】:

【参考方案1】:

是的。你可以使用softmax。更准确地说,在 softmax 上使用 argmax 来获得像 0 或 1 这样的标签预测。

y_pred = tf.nn.softmax(model.predict(test_dataset))
y_pred_argmax = tf.math.argmax(y_pred, axis=1)

当我遇到同样的问题时,这个blog 对我很有帮助..


要回答您的第二个问题,我会要求您专注于您的分类模型错误分类的测试实例,而不是试图找出模型优柔寡断的地方。

因为,argmax 总是会返回 0 或 1,而不是 0.5。而且,我会说标签 0.5 将是声称您的模型优柔寡断的适当值..

【讨论】:

以上是关于如何标准化 BERT 分类器的输出的主要内容,如果未能解决你的问题,请参考以下文章

如何定义kNN分类器的最大k?

如何将微调过的 bert 模型的输出作为输入提供给另一个微调过的 bert 模型?

bert-as-service输出分类结果

bert-as-service输出分类结果

逻辑回归推导

如何评估我自己的文本分类器