用于句子多类分类的 BertForSequenceClassification 与 BertForMultipleChoice

Posted

技术标签:

【中文标题】用于句子多类分类的 BertForSequenceClassification 与 BertForMultipleChoice【英文标题】:BertForSequenceClassification vs. BertForMultipleChoice for sentence multi-class classification 【发布时间】:2020-06-21 22:00:36 【问题描述】:

我正在处理文本分类问题(例如情感分析),我需要将文本字符串分类为五个类别之一。

我刚刚开始使用 Huggingface Transformer 包和带有 PyTorch 的 BERT。我需要的是一个顶部有一个 softmax 层的分类器,这样我就可以进行 5 路分类。令人困惑的是,Transformer 包中似乎有两个相关选项:BertForSequenceClassification 和 BertForMultipleChoice。

我应该使用哪一个来完成我的 5 路分类任务?它们有哪些合适的用例?

BertForSequenceClassification 的文档根本没有提到 softmax,尽管它确实提到了交叉熵。我不确定这个类是否仅用于 2 类分类(即逻辑回归)。

Bert 模型转换器,顶部带有序列分类/回归头(池输出顶部的线性层),例如用于 GLUE 任务。

labels(torch.LongTensor of shape (batch_size,),可选,默认为 None)- 用于计算序列分类/回归损失的标签。索引应该在 [0, ..., config.num_labels - 1] 中。如果 config.num_labels == 1 计算回归损失(均方损失),如果 config.num_labels > 1 计算分类损失(交叉熵)。

BertForMultipleChoice的文档中提到了softmax,但是标签的描述方式,听起来这个类是针对多标签分类的(即多标签的二元分类)。

顶部带有多项选择分类头的 Bert 模型(池输出顶部的线性层和 softmax),例如用于 RocStories/SWAG 任务。

labels(torch.LongTensor of shape (batch_size,),可选,默认为 None)- 用于计算多项选择分类损失的标签。索引应该在 [0, ..., num_choices] 中,其中 num_choices 是输入张量的第二维的大小。

感谢您的帮助。

【问题讨论】:

【参考方案1】:

这个问题的答案在于(诚然非常简短的)任务的描述:

[BertForMultipleChoice] [...],例如用于 RocStories/SWAG 任务。

查看paper for SWAG 时,似乎该任务实际上是在学习从不同的选项中进行选择。这与您的“经典”分类任务形成对比,其中“选择”(即类别)在您的样本中不会变化,这正是 BertForSequenceClassification 的用途。

实际上,通过更改配置中的labels 参数,这两种变体都可以用于任意数量的类(在BertForSequenceClassification 的情况下),分别用于选择(BertForMultipleChoice)。但是,由于您似乎正在处理“经典分类”的情况,我建议使用BertForSequenceClassification 模型。

简短地解决BertForSequenceClassification 中缺失的 Softmax:由于分类任务可以计算与样本无关的类之间的损失(与多选不同,您的分布正在变化),这允许您使用交叉熵损失,它会影响Softmax 在 increased numerical stability 的反向传播步骤中。

【讨论】:

谢谢。在情感分析中,问题在于给定一个句子,分类器应该从一组不变的标签(例如posnegneutral)中预测一个标签。在 RocStories 和 SWAG 中,问题在于给定一个句子,分类器应该选择几个句子中最符合第一个的句子。在这里,标签集也可以是一个小的、不变的集(例如ABC)。我不想过多考虑这个问题,所以我将使用BertForSequenceClassification 注意一下,你可以像this一样改变两者的分类层。

以上是关于用于句子多类分类的 BertForSequenceClassification 与 BertForMultipleChoice的主要内容,如果未能解决你的问题,请参考以下文章

多类文本分类:如果输入与类不匹配,则新类

在Tensorflow中限制多类分类中的输出类

如何在多类预测中得到未知类?

如何使用 R e1071 SVM 多类测试数据

用于分类/多类分类的梯度提升树的弱学习器

用于多类分类的 sklearn 指标