为啥 Bert 转换器使用 [CLS] 令牌进行分类,而不是对所有令牌进行平均?

Posted

技术标签:

【中文标题】为啥 Bert 转换器使用 [CLS] 令牌进行分类,而不是对所有令牌进行平均?【英文标题】:Why Bert transformer uses [CLS] token for classification instead of average over all tokens?为什么 Bert 转换器使用 [CLS] 令牌进行分类,而不是对所有令牌进行平均? 【发布时间】:2020-10-23 13:54:35 【问题描述】:

我正在对 bert 架构进行实验,发现大多数微调任务都将最终的隐藏层作为文本表示,然后将其传递给其他模型以进行进一步的下游任务。

Bert 的最后一层是这样的:

我们在哪里获取每个句子的 [CLS] 标记:

Image source

我对此huggingface issue、datascience forum question、github issue进行了多次讨论,大多数数据科学家都给出了这样的解释:

BERT 是双向的,[CLS] 被编码,包括所有 所有代币的代表信息通过多层 编码过程。 [CLS] 的表示在 不同的句子。

我的问题是,为什么作者忽略了其他信息(每个token的向量)并取平均,max_pool或其他方法来利用所有信息而不是使用[CLS] token进行分类?

这个 [CLS] 标记与所有标记向量的平均值相比有何帮助?

【问题讨论】:

您还可以返回所有隐藏状态并计算它们的平均/最大池化。我看到很多这样的例子 【参考方案1】:

BERT 主要用于迁移学习,即对特定任务数据集进行微调。如果对状态进行平均,则每个状态的平均权重相同:包括停用词或其他与任务无关的内容。 [CLS] 向量是使用自注意力计算的(就像 BERT 中的所有内容一样),因此它只能从其余隐藏状态中收集相关信息。因此,从某种意义上说,[CLS] 向量也是令牌向量的平均值,只是计算得更巧妙,专门针对您微调的任务。

另外,我的经验是,当我保持权重固定且微调 BERT 时,使用令牌平均值会产生更好的结果。

【讨论】:

【参考方案2】:

[CLS]token 表示整个句子的使用来自original BERT paper,第 3 节:

每个序列的第一个标记始终是一个特殊的分类标记 ([CLS])。这个token对应的最终隐藏状态作为分类任务的聚合序列表示。

你的直觉是正确的,平均所有标记的向量可能会产生更好的结果。事实上,这正是 BertModel 的 Huggingface 文档中提到的内容:

退货

pooler_output(torch.FloatTensor: 形状为(batch_size, hidden_size)):

序列的第一个标记(分类标记)的最后一层隐藏状态进一步由线性层和 Tanh 激活函数处理。线性层权重在预训练期间从下一句预测(分类)目标进行训练。

这个输出通常不能很好地总结输入的语义内容,你通常会更好地对整个输入序列的隐藏状态序列进行平均或合并

更新:Huggingface 在 v3.1.0 中删除了该声明(“此输出通常不是对语义内容的良好总结......”)。你必须问他们为什么。

【讨论】:

也许经过大量实验,这种说法被证明是错误的? 关于 [CLS] 标记的一个愚蠢问题:因此,由于每个输入序列都使用相同的 [CLS] 标记作为序列中的第一个标记,这意味着所有输入都共享相同的嵌入向量输入序列,对吧?那么我们如何将第一个令牌的最终隐藏状态用于以后的分类任务呢?我的意思是,由于 [CLS] 标记的输入嵌入在所有序列中都是共享的,那么在第一个标记的最终隐藏状态中可以表示多少差异? BERT 和其他上下文语言模型中的嵌入不是静态的。 CLS 的嵌入(即实际的 768 个浮点值)将根据输入序列而有所不同,因为它是使用所有输入令牌嵌入的注意力(即加权平均值)计算的。

以上是关于为啥 Bert 转换器使用 [CLS] 令牌进行分类,而不是对所有令牌进行平均?的主要内容,如果未能解决你的问题,请参考以下文章

删除 Bert 中的 SEP 令牌以进行文本分类

用于语义相似性的 BERT 嵌入

有没有办法获取在 BERT 中生成某个令牌的子字符串的位置?

您将使用哪种模型(GPT2、BERT、XLNet 等)进行文本分类任务?为啥?

BERT模型内部结构解析

删除并重新初始化相关的 BERT 权重/参数