为啥 huggingface bert pooler hack 让混合精度训练稳定?

Posted

技术标签:

【中文标题】为啥 huggingface bert pooler hack 让混合精度训练稳定?【英文标题】:Why does huggingface bert pooler hack make mixed precission training stable?为什么 huggingface bert pooler hack 让混合精度训练稳定? 【发布时间】:2020-06-29 19:18:54 【问题描述】:

Huggigface BERT 实现有一个 hack,可以从优化器中移除池化器。

https://github.com/huggingface/transformers/blob/b832d5bb8a6dfc5965015b828e577677eace601e/examples/run_squad.py#L927

# hack to remove pooler, which is not used
# thus it produce None grad that break apex
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

我们正在尝试在拥抱脸伯特模型上运行预训练。如果不应用此 pooler hack,则代码在训练后期总是会发散。我还看到在分类过程中使用了池化层。

pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)

池化层是一个带有 tanh 激活的 FFN

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

我的问题是为什么这个 pooler hack 解决了数值不稳定性?

pooler 遇到的问题

【问题讨论】:

【参考方案1】:

有很多资源可能比我更好地解决这个问题,例如here 或here。

具体来说,问题在于您正在处理消失(或爆炸)梯度,特别是当使用在非常小/大输入的任一方向上变平的损失函数时,sigmoid 和 tanh 都是这种情况(唯一的区别这里是它们输出的范围,分别是[0, 1][-1, 1]

此外,如果您有一个低精度的小数,例如 APEX 的情况,那么梯度消失行为很可能已经出现在相对中等的输出中,因为精度限制了它能够区分的数字从零开始。解决这个问题的一种方法是使用具有严格非零且易于计算的导数的函数,例如 Leaky ReLU,或者干脆完全避免激活函数(我假设这是 huggingface 在这里所做的)。

请注意,梯度爆炸的问题通常没有那么悲惨,因为我们可以应用梯度裁剪(将其限制为固定的最大尺寸),但原理是相同的。另一方面,对于零梯度,没有这么简单的解决方法,因为它会导致你的神经元“死亡”(零回流没有发生主动学习),这就是为什么我假设你看到了发散行为。

【讨论】:

将 tanh 激活替换为 GELU 并运行。 loss scaler 变为零的速度更快。 根据this post,我认为GELU 可能也不是一个好的替代品。否则,我只能认为是密集层引起了问题,但我不太知道如何...... 我阅读了博客。但是 GELU 在 BERT 架构中无处不在。此外,传销头使用 GELU 。为什么只有池化层会导致发散,而缩放器在训练后期下降到零?

以上是关于为啥 huggingface bert pooler hack 让混合精度训练稳定?的主要内容,如果未能解决你的问题,请参考以下文章

通过 Huggingface 转换器更新 BERT 模型

BERT HuggingFace 给出 NaN 损失

如何微调 HuggingFace BERT 模型以进行文本分类 [关闭]

如何在 HuggingFace Transformers 库中获取中间层的预训练 BERT 模型输出?

Huggingface Bert:输出打印

huggingface-transformers:训练 BERT 并使用不同的注意力对其进行评估