bert原码解析(embedding)

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了bert原码解析(embedding)相关的知识,希望对你有一定的参考价值。

参考技术A        写这篇文章的起因是看ALBERT的时候,对其中参数因式分解,减少参数的方式不理解,后来通过原码来了解原理。后来想到虽然平时基于bert的nlp任务做的挺多的,但对原理还是一知半解的,所以在此记录。后续有时间的话,将常见的,看过的论文做个总结,不然容易忘记。(attention is all your need,bert,albert,roberta,sentence -bert,simcse,consert,simbert,nezha,ernie,spanbert,gpt,xlnet,tinybert,distillbert)

从图一可以明显看出,bert主要分为三块。embedding层,encoder层,以及pooler层,本章为embedding层的原码分析。

可以看出,输入的input,会先经过tokernizer,会补上cls,sep等特殊字符。然后embedding层会获取句子的token embeddings+segment embeddings+position embeddings作为最终的句子embedding。

1 token embedding:

token embedding有两种初始化方式。如果是训练预训练,随机出初始化一个30522*768的lookup table(根据wordpiece算法,英文一共有30522个sub-word就可以代表所有词汇,每个sub-word 768纬)。如果是在预训练模型的基础上finetune,读取预训练模型训练好的lookup table。假设输入的句子经过tokernized长度为16。经过lookup table就是16*768维的句子表示。

2 position embedding:

position embedding的lookup table 大小512*768,说明bert最长处理长度为512的句子。长于512有几种截断获取的方式。position embedding的生成方式有两种:1 根据公式直接生成 2 根据反向传播计算梯度更新。其中,transformer使用公式直接生成,公式为:

其中,pos指的是这个word在这个句子中的位置;2i指的是embedding词向量的偶数维度,2i+1指的是embedding词向量的奇数维度。为什么这个公式能代表单词在句子中的位置信息呢?因为位置编码基于不同位置添加了正弦波,对于每个维度,波的频率和偏移都有不同。也就是说对于序列中不同位置的单词,对应不同的正余弦波,可以认为他们有相对关系。优点在于减少计算量了,只需要一次初始化不需要后续更新。

其中, bert使用的是根据反向传播计算梯度更新。

3 segment embedding:

bert输入可以为两句话。[cls]....[seq]....[seq]。每句话结尾以seq分割。从embedding的大小可以看出,lookup table由两个768组成,对应第一句和第二句。该参数也由训练得到。

4 LN以及dropout:

embeddings = dropout(layernorm(token embeddings+segment embeddings+position embeddings))。Normalization 有很多种,但是它们都有一个共同的目的,那就是把输入转化成均值为 0 方差为1的数据。我们在把数据送入激活函数之前进行normalization(归一化),因为我们不希望输入数据落在激活函数的饱和区,发生梯度消失的问题,使得我们的模型训练变得困难。这里不使用bn可以去除batch size对模型的影响。

下一篇为bert核心encoder模块的解析。

以上是关于bert原码解析(embedding)的主要内容,如果未能解决你的问题,请参考以下文章

BERT模型解析

BERT模型解析

BERT模型解析

Bert4keras解决Key bert/embeddings/word_embeddings not found in checkpoint

Bert4keras解决Key bert/embeddings/word_embeddings not found in checkpoint

BERT 结构与原理(1)--Embedding