为啥采用 HuggingFace 进行序列分类的第一个隐藏状态(DistilBertForSequenceClassification)

Posted

技术标签:

【中文标题】为啥采用 HuggingFace 进行序列分类的第一个隐藏状态(DistilBertForSequenceClassification)【英文标题】:why take the first hidden state for sequence classification (DistilBertForSequenceClassification) by HuggingFace为什么采用 HuggingFace 进行序列分类的第一个隐藏状态(DistilBertForSequenceClassification) 【发布时间】:2020-05-22 01:33:54 【问题描述】:

在HuggingFace的最后几层序列分类中,他们将transformer输出的序列长度的第一个隐藏状态用于分类。

hidden_state = distilbert_output[0]  # (bs, seq_len, dim) <-- transformer output
pooled_output = hidden_state[:, 0]  # (bs, dim)           <-- first hidden state
pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
pooled_output = self.dropout(pooled_output)  # (bs, dim)
logits = self.classifier(pooled_output)  # (bs, dim)

取第一个隐藏状态比使用最后一个、平均甚至使用 Flatten 层有什么好处吗?

【问题讨论】:

【参考方案1】:

是的,这与 BERT 的训练方式直接相关。具体来说,我鼓励你看看original BERT paper,作者在其中介绍了[CLS]令牌的含义:

[CLS] 是添加在每个输入示例前面的特殊符号 [...]。

具体来说,它用于分类目的,因此是分类任务微调的首选和最简单的选择。您的相关代码片段正在做什么,基本上只是提取这个[CLS] 令牌。

不幸的是,Huggingface 库的 DistilBERT 文档并未明确提及这一点,但您必须查看他们的 BERT documentation,其中他们还强调了 [CLS] 令牌的一些问题,类似于您的担忧:

除了 MLM,BERT 还使用下一句预测 (NSP) 进行训练 目标使用 [CLS] 标记作为序列近似值。用户 可以使用这个标记(使用特殊构建的序列中的第一个标记 令牌)来获得序列预测而不是令牌预测。 但是,对序列进行平均可能会产生比 使用 [CLS] 令牌。

【讨论】:

+1。如果对序列的嵌入进行平均可以产生更好的结果,为什么那些作者没有采用这种方法? 我认为替代方案的计算量更大,因此不值得(可能只是边际)收益。

以上是关于为啥采用 HuggingFace 进行序列分类的第一个隐藏状态(DistilBertForSequenceClassification)的主要内容,如果未能解决你的问题,请参考以下文章

如何微调 HuggingFace BERT 模型以进行文本分类 [关闭]

EncoderDecoderModel 转换解码器的分类器层

HuggingFace Saving-Loading 模型 (Colab) 进行预测

如何获得 Huggingface Transformer 模型预测 [零样本分类] 的 SHAP 值?

为啥 huggingface bert pooler hack 让混合精度训练稳定?

将 AllenNLP 解释与 HuggingFace 模型一起使用