BERT知识蒸馏TinyBERT
Posted zhiyong_will
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了BERT知识蒸馏TinyBERT相关的知识,希望对你有一定的参考价值。
1. 概述
诸如BERT等预训练模型的提出显著的提升了自然语言处理任务的效果,但是随着模型的越来越复杂,同样带来了很多的问题,如参数过多,模型过大,推理事件过长,计算资源需求大等。近年来,通过模型压缩的方式来减小模型的大小也是一个重要的研究方向,其中,知识蒸馏也是常用的一种模型压缩方法。TinyBERT[1]是一种针对transformer-based模型的知识蒸馏方法,以BERT为Teacher模型蒸馏得到一个较小的模型TinyBERT。四层结构的TinyBERT在GLUE benchmark上可以达到BERT的96.8%及以上的性能表现,同时模型缩小7.5倍,推理速度提升9.4倍。六层结构的TinyBERT可以达到和BERT同样的性能表现。
2. 算法原理
为了能够将原始的BERT模型蒸馏到TinyBERT,因此,在[1]中提出了一种新的针对Transformer网络特殊设计的蒸馏方法,同时,因为BERT模型的训练分成了两个部分,分别为预训练和针对特定任务的Fine-tuning,因此在TinyBERT模型的蒸馏训练过程中也设计了两阶段的学习框架,在预训练和Fine-tuning阶段都进行蒸馏,以确保TinyBERT模型能够从BERT模型中学习到一般的语义知识和特定任务知识。
2.1. 知识蒸馏
知识蒸馏(knowledge distillation)[2]是模型压缩的一种常用的方法,对于一个完整的知识蒸馏过程,有两个模型,分别为Teacher模型和Student模型,通过学习将已经训练好的Teacher模型中的知识迁移到小的Student模型中。其具体过程如下图所示:
对于Student模型,其目标函数有两个,分别为蒸馏的loss(distillation loss)和自身的loss(student loss),Student模型最终的损失函数为:
L = α L s o f t + β L h a r d L=\\alpha L_soft+\\beta L_hard L=αLsoft+βLhard
其中, L s o f t L_soft Lsoft表示的是蒸馏的loss, L h a r d L_hard Lhard表示的是自身的loss。
2.2. Transformer Distillation
BERT模型是由多个Transformer模块(Self-Attention+FFN)组成,单个Self-Attention+FFN模块如下图所示:
假设BERT模型中有 N N N层的Transformer Layer,在蒸馏的过程中,BERT模型作为Teacher模型,而需要蒸馏的模型TinyBERT模型作为Student模型,其Transformer Layer的层数假设为 M M M,则有 M < N M<N M<N,此时需要找到一个对应关系: n = g ( m ) n = g\\left ( m \\right ) n=g(m),表示的是在Student模型中的第 m m m层对应于Teacher模型中的第 n n n层,即 g ( m ) g\\left ( m \\right ) g(m)层。TinyBERT的Embedding层和预测层也是从BERT的相应层学习知识的,其中Embedding层对应的层数为 0 0 0,预测层对应的层数为 M + 1 M+1 M+1,对应到BERT中的层数分别为 0 = g ( 0 ) 0=g\\left (0 \\right ) 0=g(0) 和 N + 1 = g ( M + 1 ) N + 1 = g\\left ( M+1 \\right ) N+1=g(M+1)。在形式上,学生模型可以通过最小化以下的目标函数来获取教师模型的知识:
L m o d e l = ∑ x ∈ χ ∑ m = 0 M + 1 λ m L l a y e r ( f m S ( x ) , f g ( m ) T ( x ) ) L_model=\\sum _x\\in \\chi \\sum_m=0^M+1\\lambda _mL_layer\\left ( f_m^S\\left ( x \\right ),f_g\\left ( m \\right )^T\\left ( x \\right ) \\right ) Lmodel=x∈χ∑m=0∑M+1λmLlayer(fmS(x),fg(m)T(x))
其中, L l a y e r L_layer Llayer是给定的模型层的损失函数, f m f_m fm表示的是由第 m m m层得到的结果, λ m \\lambda_m λm表示第 m m m层蒸馏的重要程度。在TinyBERT的蒸馏过程中,又可以分为以下三个部分:
- transformer-layer distillation
- embedding-layer distillation
- prediction-layer distillation。
2.2.1. Transformer-layer Distillation
Transformer-layer的蒸馏由Attention Based蒸馏和Hidden States Based蒸馏两部分组成,具体如下图所示:
其中,在BERT中多头注意力层能够捕获到丰富的语义信息,因此,在蒸馏到TinyBERT中,提出了Attention Based蒸馏,其目的是希望使得蒸馏后的Student模型能够从Teacher模型中学习到这些语义上的信息。具体到模型中,就是让TinyBERT网络学习拟合BERT网络中的多头注意力矩阵,目标函数定义如下:
L a t t n = 1 h ∑ i = 1 h M S E ( A i S , A i T ) L_attn=\\frac1h\\sum_i=1^hMSE\\left ( A_i^S,A_i^T \\right ) Lattn=h1i=1∑hMSE(AiS,AiT)
其中, h h h代表注意力头数, A i ∈ R l × l A_i \\in \\mathbbR^l\\times l Ai∈Rl×l代表Student或者Teacher模型中的第 i i i个注意力头对应的注意力矩阵, l l l代表输入文本的长度。在[1]中使用注意力矩阵 A A A而不是 s o f t m a x ( A ) softmax\\left ( A \\right ) softmax(A)是因为实验结果显示这样可以得到更快的收敛速度和更好的性能表现。
Hidden States Based的蒸馏是对Transformer层进行了知识蒸馏处理,目标函数定义为:
L h i d n = M S E ( H S W h , H T ) L_hidn=MSE\\left ( H^SW_h,H^T \\right ) Lhidn=MSE(HSWh,HT)
其中,矩阵 H S ∈ R l × d ′ H^S\\in \\mathbbR^l\\times d' HS∈Rl×d′和 H T ∈ R l × d H^T\\in \\mathbbR^l\\times d HT∈Rl×d分别代表Student网络和Teacher网络的隐状态,且都是FFN的输出。 d d d和 d ′ d' d′代表Teacher网络和Student网络的隐藏状态大小,且 d ′ < d d' < d d′<d,因为Student网络总是小于Teacher网络。 W h ∈ R d ′ × d W_h\\in \\mathbbR^d'\\times d Wh∈Rd′×d是一个参数矩阵,将Student网络的隐藏状态投影到Teacher网络隐藏状态所在的空间。
2.2.2. Embedding-layer Distillation
Embedding层的蒸馏与Hidden States Based蒸馏一致,其目标函数为:
L e m b d = M S E ( E S W e , E T ) L_embd=MSE\\left ( E^SW_e,E^T \\right ) Lembd=MSE(ESWe,ET)
其中 E S E^S ES, E T E^T ET分别代表Student网络和Teacher网络的Embedding, W e W_e We的作用与 W h W_h Wh的作用一致。
2.2.3. Prediction-layer Distillation
除了对中间层做蒸馏,同样对于最终的预测层也要进行蒸馏,其目标函数为:
L
p
r
e
d
=
C
E
(
z
T
t
,
z
S
t
)
L_pred=CE\\left ( \\fracz^Tt,\\fracz^St \\right )
以上是关于BERT知识蒸馏TinyBERT的主要内容,如果未能解决你的问题,请参考以下文章