如何标准化 BERT 分类器的输出
Posted
技术标签:
【中文标题】如何标准化 BERT 分类器的输出【英文标题】:How to normalize output from BERT classifier 【发布时间】:2020-11-16 15:23:54 【问题描述】:我已经使用 HuggingFace transformers.TFBertForSequenceClassification
分类器训练了一个 BERT 分类器。它工作正常,但是当使用 model.predict() 方法时,它会给出一个元组作为输出,该元组在 [0, 1] 之间没有标准化。例如。我训练模型将新闻文章分类为欺诈和非欺诈类别。然后我将以下4个测试数据输入模型进行预测:
articles = ['He was involved in the insider trading scandal.',
'Johnny was a good boy. May his soul rest in peace',
'The fraudster stole money using debit card pin',
'Sun rises in the east']
输出是:
[[-2.8615277, 2.6811066],
[ 2.8651822, -2.564444 ],
[-2.8276567, 2.4451752],
[ 2.770451 , -2.3713884]]
对我来说,label-0 代表非欺诈,label-1 代表欺诈,所以效果很好。但是我如何从这里准备得分信心呢?在这种情况下,使用 softmax 进行归一化是否有意义?另外,如果我想看看模型有点优柔寡断的那些预测,我该怎么做?在那种情况下,这两个值会非常接近吗?
【问题讨论】:
【参考方案1】:是的。你可以使用softmax。更准确地说,在 softmax 上使用 argmax 来获得像 0 或 1 这样的标签预测。
y_pred = tf.nn.softmax(model.predict(test_dataset))
y_pred_argmax = tf.math.argmax(y_pred, axis=1)
当我遇到同样的问题时,这个blog 对我很有帮助..
要回答您的第二个问题,我会要求您专注于您的分类模型错误分类的测试实例,而不是试图找出模型优柔寡断的地方。
因为,argmax 总是会返回 0 或 1,而不是 0.5。而且,我会说标签 0.5 将是声称您的模型优柔寡断的适当值..
【讨论】:
以上是关于如何标准化 BERT 分类器的输出的主要内容,如果未能解决你的问题,请参考以下文章