[论文笔记]DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter
Posted 愤怒的可乐
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[论文笔记]DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter相关的知识,希望对你有一定的参考价值。
引言
本文是DistilBERT1的阅读笔记。
核心思想
DistilBERT是一个更小更快的BERT模型,类似ALBERT,也是用来给BERT瘦身的。
DistilBERT应用了基于三重损失(Triplet Loss)的知识蒸馏(knowledge distillation)方法。相比BERT模型,DistilBERT的参数量压缩至原来的40%,同时带来60%的推理速度提升,并且在多个下游任务上达到BERT模型效果的97%。
并且该模型可以放到像手机📲(on-device)这类设备上运行,具备的好处就是更好的隐私保护,一些隐私数据可以不用上传到服务器,直接在手机端针对这些数据就可以为人们带来个性化的服务。
模型剖析
DistilBERT的名字中Distil就是蒸馏的意思,我们先来看下什么是蒸馏。
蒸馏
蒸馏的解释是加热液体汽化,再使蒸气液化,从而除去其中的杂质。
而这里的知识蒸馏是指将已经训练好的模型包含的知识(Knowledge),蒸馏(Distill)提取到另一个模型里面去。
通常前者是一个较大的模型,后者是一个较小的模型。
从另外一个角度思考的话,我们让小模型来学习大模型。因此,我们把大模型当成老师(Teacher),小模型当成学生(Student)2。
蒸馏的目标是让学生模型学习到老师模型的泛化能力,而不是学习拟合训练数据,理论上得到的结果会比单纯拟合训练数据要好。
训练损失
为了将教师模型的知识传输到学生模型,DistilBERT采用了三重损失3:有监督MLM损失、蒸馏MLM损失和词向量余弦损失,如下所示:
L
=
L
s
−
m
l
m
+
L
d
−
m
l
m
+
L
c
o
s
\\mathcal{L}=\\mathcal{L}^{s-mlm} + \\mathcal{L}^{d-mlm}+\\mathcal{L}^{cos}
L=Ls−mlm+Ld−mlm+Lcos
有监督MLM损失 利用掩码语言模型训练得到的损失,即通过输入带有掩码的句子,得到每个掩码位置在词表空间上的概率分布,并利用交叉熵损失函数学习。有监督MLM损失的计算方法为:
L
s
−
m
l
m
=
−
∑
i
y
i
log
(
s
i
)
\\mathcal{L}^{s-mlm}= -\\sum_i y_i \\log (s_i)
Ls−mlm=−i∑yilog(si)
其中,
y
i
y_i
yi表示第
i
i
i个类别的标签;
s
i
s_i
si表示学生模型对该类别的输出概率。
蒸馏MLM损失 利用教师模型的概率作为指导信号,与学生模型的概率计算交叉熵损失进行学习。由于教师模型是已经训练过的预训练语言模型,其输出的概率分布相比学生模型更加准确,能够起到一定的监督训练目的。因此,在预训练语言模型的知识蒸馏中,通常将有监督MLM称作硬标签(Hard Label)训练方法,将蒸馏MLM称作软标签(Soft Label)训练方法。硬标签对应真实的MLM训练标签,而软标签是教师模型输出的概率。蒸馏MLM损失的计算方法为:
L
d
−
m
l
m
=
−
∑
i
t
i
log
(
s
i
)
\\mathcal{L}^{d-mlm} = -\\sum_i t_i \\log(s_i)
Ld−mlm=−i∑tilog(si)
其中,
t
i
t_i
ti表示教师模型对第
i
i
i个类别的输出概率;
s
i
s_i
si表示学生模型对该类别的输出概率。对比上面两个式子可以很容易看出有监督MLM损失和蒸馏MLM损失之间的区别。需要注意的是,当计算概率
t
i
t_i
ti和
s
i
s_i
si时,DistilBERT采用了带有温度系数的Softmax函数:
P
i
=
exp
(
z
i
/
T
)
∑
j
exp
(
z
j
/
T
)
P_i = \\frac{\\exp(z_i/T)}{\\sum_j \\exp(z_j/T)}
Pi=∑jexp(zj/T)exp(zi/T)
其中,
P
i
P_i
Pi表示带有温度的概率值,
t
i
t_i
ti和
s
i
s_i
si均使用该方法计算;
z
i
z_i
zi和
z
j
z_j
zj表示为激活的数值;
T
T
T表示蒸馏里面的温度系数,用于控制输出概率的平滑程度。在训练阶段,教师模型和学生模型设置同样的温度
T
T
T,此时一般将温度系数设为
T
=
8
T=8
T=8。在推理阶段,将温度系数设成
T
=
1
T=1
T=1,还原标准的Softmax函数。
词向量余弦损失 词向量余弦损失用来对齐教师模型和学生模型的隐藏状态向量的方向,从隐藏状态维度拉近教师模型和学生模型的距离,如下:
L
c
o
s
=
cos
(
h
t
,
h
s
)
\\mathcal{L}^{cos} = \\cos(h^t,h^s)
Lcos=cos(ht,hs)
其中,
h
t
h^t
ht和
h
s
h^s
hs分别表示教师模型和学生模型最后一层的隐藏状态输出。
DistilBERT:一个蒸馏版本的BERT
学生模型结构 学生模型(DistilBERT)的基本结构是一个六层的BERT模型,同时去掉了标记类型嵌入和池化模块(Pooler)。线性层和层归一化层已经被高度优化且证明有效,因此作者不改动。最后一层的隐藏向量大小,作者发现减少该值并不太影响模型效果。层数能影响模型效果和推理速度,因此作者注重于此参数优化。
学生模型初始化 教师模型直接使用了原版的BERT-base模型。由于教师模型和学生模型的前六层结构基本相同,为了最大化复用教师模型中的知识,学生模型使用了教师模型的前六层进行初始化。
蒸馏 DistilBERT 是在非常大的批次上使用动态掩码利用梯度累积(每批次最多 4K 个样本)进行蒸馏的,没有下一句预测目标。
评估
GLUE : DistilBERT的参数量压缩至原来的40%,并且在多个下游任务上达到BERT模型效果的97%,甚至在WNLI任务上超过了BERT。
IMDb准确率 :BERT(93.46) DistilBERT(92.82)
推理速度(跑了所有的GLUE任务):DistilBERT(410s) BERT(668s) ELMo(895s)
参考
以上是关于[论文笔记]DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter的主要内容,如果未能解决你的问题,请参考以下文章
使用 huggingface 的 distilbert 模型生成文本