为啥采用 HuggingFace 进行序列分类的第一个隐藏状态(DistilBertForSequenceClassification)
Posted
技术标签:
【中文标题】为啥采用 HuggingFace 进行序列分类的第一个隐藏状态(DistilBertForSequenceClassification)【英文标题】:why take the first hidden state for sequence classification (DistilBertForSequenceClassification) by HuggingFace为什么采用 HuggingFace 进行序列分类的第一个隐藏状态(DistilBertForSequenceClassification) 【发布时间】:2020-05-22 01:33:54 【问题描述】:在HuggingFace的最后几层序列分类中,他们将transformer输出的序列长度的第一个隐藏状态用于分类。
hidden_state = distilbert_output[0] # (bs, seq_len, dim) <-- transformer output
pooled_output = hidden_state[:, 0] # (bs, dim) <-- first hidden state
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, dim)
取第一个隐藏状态比使用最后一个、平均甚至使用 Flatten 层有什么好处吗?
【问题讨论】:
【参考方案1】:是的,这与 BERT 的训练方式直接相关。具体来说,我鼓励你看看original BERT paper,作者在其中介绍了[CLS]
令牌的含义:
[CLS]
是添加在每个输入示例前面的特殊符号 [...]。
具体来说,它用于分类目的,因此是分类任务微调的首选和最简单的选择。您的相关代码片段正在做什么,基本上只是提取这个[CLS]
令牌。
不幸的是,Huggingface 库的 DistilBERT 文档并未明确提及这一点,但您必须查看他们的 BERT documentation,其中他们还强调了 [CLS]
令牌的一些问题,类似于您的担忧:
除了 MLM,BERT 还使用下一句预测 (NSP) 进行训练 目标使用 [CLS] 标记作为序列近似值。用户 可以使用这个标记(使用特殊构建的序列中的第一个标记 令牌)来获得序列预测而不是令牌预测。 但是,对序列进行平均可能会产生比 使用 [CLS] 令牌。
【讨论】:
+1。如果对序列的嵌入进行平均可以产生更好的结果,为什么那些作者没有采用这种方法? 我认为替代方案的计算量更大,因此不值得(可能只是边际)收益。以上是关于为啥采用 HuggingFace 进行序列分类的第一个隐藏状态(DistilBertForSequenceClassification)的主要内容,如果未能解决你的问题,请参考以下文章
如何微调 HuggingFace BERT 模型以进行文本分类 [关闭]
EncoderDecoderModel 转换解码器的分类器层
HuggingFace Saving-Loading 模型 (Colab) 进行预测
如何获得 Huggingface Transformer 模型预测 [零样本分类] 的 SHAP 值?