8.3 bert的蒸馏讲解 意境级
Posted 炫云云
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了8.3 bert的蒸馏讲解 意境级相关的知识,希望对你有一定的参考价值。
文章目录
相关论文下载:模型压缩方法与bert压缩的论文
最近两年迁移学习在NLP领域发展迅猛,预训练+微调在各项任务上已经取得了很不错的成绩。但是这些模型有一个共同的特点就是参数多,规模大,训练时间长而且成本很高。于是就有了各种方法对bert进行压缩,下面来看看各种方法吧。
相关论文下载:
TinyBERT
TinyBERT1是一种对 BERT 进行知识蒸馏压缩后的模型由华中科技和华为的研究人员提出。
知识蒸馏首先训练一个大的 teacher 模型,然后使用 teacher 模型输出的预测值训练小的 student 模型。student 模型学习 teacher 模型的预测结果 (概率值) 从而学习到 teacher 模型的泛化能力。
之前知识蒸馏的损失函数主要是针对 teacher 模型输出的预测概率值,而 TinyBERT 的损失函数包括四个部分:Embedding 层的损失,Transformer 层 attention 的损失,Transformer 层 hidden state 的损失和最后预测层的损失。即 student 模型不仅仅学习 teacher 模型的预测概率,也学习其 Embedding 层和 Transformer 层的特性。
上面的图片展示了 TinyBERT (studet) 和 BERT (teacher) 的结构,可以看到 TinyBERT 减少了 BERT 的层数,并且减小了 BERT 隐藏层的维度。
TinyBERT 蒸馏过程中的损失函数主要包含以下四个2:
- Embedding 层损失函数
- Transformer 层 attention 损失函数
- Transformer 层 hidden state 损失函数
- 预测层损失函数
我们先看一下 TinyBERT 蒸馏时候每一层的映射方法。
1、TinyBERT 蒸馏的映射方法
假设 TinyBERT 有 M M M 个 Transformer 层,而 BERT 有 N N N 个 Transformer 层。TinyBERT 蒸馏主要涉及的层有 embedding 层 (编号为0)、Transformer 层 (编号为1到M) 和输出层 (编号 M + 1 M+1 M+1)。
我们需要将 TinyBERT 每一层和 BERT 中要学习的层对应起来,然后再蒸馏。对应的函数为 g ( m ) = n g(m) = n g(m)=n, m m m 是 TinyBERT 层的编号, n n n 是 BERT 层的编号。
对于 embedding 层,TinyBERT 蒸馏的时候 embedding 层 (0) 对应了 BERT 的 embedding 层 (0),即 g ( 0 ) = 0 g(0) = 0 g(0)=0。
对于输出层,TinyBERT 的输出层 ( M + 1 ) (M+1) (M+1) 对应了 BERT 的输出层 ( N + 1 ) (N+1) (N+1),即 g ( M + 1 ) = N + 1 g(M+1) = N+1 g(M+1)=N+1。
对于中间的 Transformer 层,TinyBERT 采用 k k k 层蒸馏的方法,即 g ( m ) = m × N / M g(m) = m × N / M g(m)=m×N/M。例如 TinyBERT 有 4 层 Transformer,BERT 有 12 12 12 层 Transformer,则 TinyBERT 第 1 层 Transformer 学习的是 BERT 的第 3 层;而TinyBERT 第 2 2 2 层学习 BERT 的第 6 6 6 层。
2、Embedding 层损失函数
L
embed
=
MSE
(
E
S
W
e
,
E
T
)
E
S
∈
R
l
×
d
′
E
T
∈
R
l
×
d
\\begin{array}{c} L_{\\text {embed }}=\\operatorname{MSE}\\left(\\boldsymbol{E}^{S} \\boldsymbol{W}_{e}, \\boldsymbol{E}^{T}\\right) \\\\ \\boldsymbol{E}^{S} \\in R^{l \\times d^{\\prime}} \\quad \\boldsymbol{E}^{T} \\in R^{l \\times d} \\end{array}
Lembed =MSE(ESWe,ET)ES∈Rl×d′ET∈Rl×d
E
S
E^S
ES 是 TinyBERT 的 embedding,
E
T
E^T
ET 是 BERT 的 embedding,
l
l
l 是句子序列的长度,而
d
‘
d‘
d‘ 是 TinyBERT embedding 维度,
d
d
d 是 BERT embedding 维度。因为是要压缩 BERT 模型,所以
d
′
<
d
d' < d
d′<d,TinyBERT 希望模型学到的 embedding 与 BERT 原来的 embedding 具有相似的语义,因此采用了上面的损失函数,减少两者 embedding 的差异。
embedding 维度不同,不能直接计算 loss,因此 TInyBERT 增加了一个映射矩阵 W e ∈ ( d ′ × d ) W_e\\in (d'×d) We∈(d′×d) 的矩阵, E S E^S ES 乘以映射矩阵后维度与 E T E^T ET 一样。embedding loss 就是二者的均方误差 MSE。
3、Transformer 层 attention 损失函数
TinyBERT 在 Transformer 层损失函数有两个,第一个是 attention loss,如下图所示。
attention loss 主要是希望 TinyBERT Multi-Head Attention 部分输出的 attention score 矩阵 能够接近 BERT 的 attention score 矩阵。因为有研究发现 BERT 学习到的 attention score 矩阵能够包含语义知识,例如语法和相互关系等,具体可参考论文《What Does BERT Look At? An Analysis of BERT’s Attention》。TinyBERT 通过下面的损失函数学习 BERT attention 的功能,
h
h
h 表示 Multi-Head Attention 中 head 的个数。
L
a
t
t
n
=
1
h
∑
i
=
1
h
MSE
(
A
i
S
,
A
i
T
)
L_{\\mathrm{attn}}=\\frac{1}{h} \\sum_{i=1}^{h} \\operatorname{MSE}\\left(\\boldsymbol{A}_{i}^{S}, \\boldsymbol{A}_{i}^{T}\\right)
Lattn=h1i=1∑hMSE(AiS,AiT)
4、Transformer 层 hidden state 损失函数
TinyBERT 在 Transformer 层的第二个损失函数是 hidden loss,如下图所示。
hidden state loss 和 embedding loss 类似,计算公式如下,也需要经过一个映射矩阵。
L
hide
=
MSE
(
H
S
W
h
,
H
T
)
L_{\\text {hide }}=\\operatorname{MSE}\\left(\\boldsymbol{H}^{S} \\boldsymbol{W}_{h}, \\boldsymbol{H}^{T}\\right)
Lhide =MSE(HSWh,HT)
5、预测层损失函数
预测层的损失函数采用了交叉熵,计算公式如下,其中 t 是模型蒸馏的 temperature value,zT 是 BERT 的预测概率,而 zS 是 TinyBERT 的预测概率。
L
pred
=
−
softmax
(
z
T
)
⋅
log
−
softmax
(
z
S
/
t
)
L_{\\text {pred }}=-\\operatorname{softmax}\\left(z^{T}\\right) \\cdot \\log _{-} \\operatorname{softmax}\\left(z^{S} / t\\right)
Lpred =−softmax(zT)⋅log−softmax(zS/t)
6、总体损失函数
L layer ( S m , T g ( m ) ) = { L embd ( S 0 , T 0 ) , m = 0 L hidn ( S m , T g ( m ) ) + L attn ( S m , T g ( m ) ) , M ≥ m > 0 L pred ( S M + 1 , T N + 1 ) , m = M + 1 \\mathcal{L}_{\\text {layer }}\\left(S_{m}, T_{g(m)}\\right)=\\left\\{\\begin{array}{ll} \\mathcal{L}_{\\text {embd }}\\left(S_{0}, T_{0}\\right), & m=0 \\\\ \\mathcal{L}_{\\text {hidn }}\\left(S_{m}, T_{g(m)}\\right)+\\mathcal{L}_{\\text {attn }}\\left(S_{m}, T_{g(m)}\\right), & M \\geq m>0 \\\\ \\mathcal{L}_{\\text {pred }}\\left(S_{M+1}, T_{N+1}\\right), & m=M+1 \\end{array}\\right. Llayer (Sm,Tg(m))=⎩⎨⎧Lembd (S0,T0),Lhidn (S以上是关于8.3 bert的蒸馏讲解 意境级的主要内容,如果未能解决你的问题,请参考以下文章