8.3 bert的蒸馏讲解 意境级

Posted 炫云云

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了8.3 bert的蒸馏讲解 意境级相关的知识,希望对你有一定的参考价值。

8.1 模型压缩的方法

8.2 知识蒸馏 讲解 意境级

8.3 bert的蒸馏讲解 意境级

8.4 bert的压缩讲解 意境级

相关论文下载:模型压缩方法与bert压缩的论文

最近两年迁移学习在NLP领域发展迅猛,预训练+微调在各项任务上已经取得了很不错的成绩。但是这些模型有一个共同的特点就是参数多,规模大,训练时间长而且成本很高。于是就有了各种方法对bert进行压缩,下面来看看各种方法吧。

相关论文下载:

TinyBERT

TinyBERT1是一种对 BERT 进行知识蒸馏压缩后的模型由华中科技和华为的研究人员提出。

知识蒸馏首先训练一个大的 teacher 模型,然后使用 teacher 模型输出的预测值训练小的 student 模型。student 模型学习 teacher 模型的预测结果 (概率值) 从而学习到 teacher 模型的泛化能力。

之前知识蒸馏的损失函数主要是针对 teacher 模型输出的预测概率值,而 TinyBERT 的损失函数包括四个部分:Embedding 层的损失,Transformer 层 attention 的损失,Transformer 层 hidden state 的损失和最后预测层的损失。即 student 模型不仅仅学习 teacher 模型的预测概率,也学习其 Embedding 层和 Transformer 层的特性。

上面的图片展示了 TinyBERT (studet) 和 BERT (teacher) 的结构,可以看到 TinyBERT 减少了 BERT 的层数,并且减小了 BERT 隐藏层的维度。

TinyBERT 蒸馏过程中的损失函数主要包含以下四个2

  • Embedding 层损失函数
  • Transformer 层 attention 损失函数
  • Transformer 层 hidden state 损失函数
  • 预测层损失函数

我们先看一下 TinyBERT 蒸馏时候每一层的映射方法。

1、TinyBERT 蒸馏的映射方法

假设 TinyBERT 有 M M M 个 Transformer 层,而 BERT 有 N N N 个 Transformer 层。TinyBERT 蒸馏主要涉及的层有 embedding 层 (编号为0)、Transformer 层 (编号为1到M) 和输出层 (编号 M + 1 M+1 M+1)。

我们需要将 TinyBERT 每一层和 BERT 中要学习的层对应起来,然后再蒸馏。对应的函数为 g ( m ) = n g(m) = n g(m)=n m m m 是 TinyBERT 层的编号, n n n 是 BERT 层的编号。

对于 embedding 层,TinyBERT 蒸馏的时候 embedding 层 (0) 对应了 BERT 的 embedding 层 (0),即 g ( 0 ) = 0 g(0) = 0 g(0)=0

对于输出层,TinyBERT 的输出层 ( M + 1 ) (M+1) (M+1) 对应了 BERT 的输出层 ( N + 1 ) (N+1) (N+1),即 g ( M + 1 ) = N + 1 g(M+1) = N+1 g(M+1)=N+1

对于中间的 Transformer 层,TinyBERT 采用 k k k 层蒸馏的方法,即 g ( m ) = m × N / M g(m) = m × N / M g(m)=m×N/M。例如 TinyBERT 有 4 层 Transformer,BERT 有 12 12 12 层 Transformer,则 TinyBERT 第 1 层 Transformer 学习的是 BERT 的第 3 层;而TinyBERT 第 2 2 2 层学习 BERT 的第 6 6 6 层。

2、Embedding 层损失函数

L embed  = MSE ⁡ ( E S W e , E T ) E S ∈ R l × d ′ E T ∈ R l × d \\begin{array}{c} L_{\\text {embed }}=\\operatorname{MSE}\\left(\\boldsymbol{E}^{S} \\boldsymbol{W}_{e}, \\boldsymbol{E}^{T}\\right) \\\\ \\boldsymbol{E}^{S} \\in R^{l \\times d^{\\prime}} \\quad \\boldsymbol{E}^{T} \\in R^{l \\times d} \\end{array} Lembed =MSE(ESWe,ET)ESRl×dETRl×d
E S E^S ES 是 TinyBERT 的 embedding, E T E^T ET 是 BERT 的 embedding, l l l句子序列的长度,而 d ‘ d‘ d 是 TinyBERT embedding 维度, d d d 是 BERT embedding 维度。因为是要压缩 BERT 模型,所以 d ′ < d d' < d d<d,TinyBERT 希望模型学到的 embedding 与 BERT 原来的 embedding 具有相似的语义,因此采用了上面的损失函数,减少两者 embedding 的差异。

embedding 维度不同,不能直接计算 loss,因此 TInyBERT 增加了一个映射矩阵 W e ∈ ( d ′ × d ) W_e\\in (d'×d) We(d×d) 的矩阵, E S E^S ES 乘以映射矩阵后维度与 E T E^T ET 一样。embedding loss 就是二者的均方误差 MSE。

3、Transformer 层 attention 损失函数

TinyBERT 在 Transformer 层损失函数有两个,第一个是 attention loss,如下图所示。

attention loss 主要是希望 TinyBERT Multi-Head Attention 部分输出的 attention score 矩阵 能够接近 BERT 的 attention score 矩阵。因为有研究发现 BERT 学习到的 attention score 矩阵能够包含语义知识,例如语法和相互关系等,具体可参考论文《What Does BERT Look At? An Analysis of BERT’s Attention》。TinyBERT 通过下面的损失函数学习 BERT attention 的功能, h h h 表示 Multi-Head Attention 中 head 的个数。
L a t t n = 1 h ∑ i = 1 h MSE ⁡ ( A i S , A i T ) L_{\\mathrm{attn}}=\\frac{1}{h} \\sum_{i=1}^{h} \\operatorname{MSE}\\left(\\boldsymbol{A}_{i}^{S}, \\boldsymbol{A}_{i}^{T}\\right) Lattn=h1i=1hMSE(AiS,AiT)

4、Transformer 层 hidden state 损失函数

TinyBERT 在 Transformer 层的第二个损失函数是 hidden loss,如下图所示。

hidden state loss 和 embedding loss 类似,计算公式如下,也需要经过一个映射矩阵。
L hide  = MSE ⁡ ( H S W h , H T ) L_{\\text {hide }}=\\operatorname{MSE}\\left(\\boldsymbol{H}^{S} \\boldsymbol{W}_{h}, \\boldsymbol{H}^{T}\\right) Lhide =MSE(HSWh,HT)

5、预测层损失函数

预测层的损失函数采用了交叉熵,计算公式如下,其中 t 是模型蒸馏的 temperature value,zT 是 BERT 的预测概率,而 zS 是 TinyBERT 的预测概率。
L pred  = − softmax ⁡ ( z T ) ⋅ log ⁡ − softmax ⁡ ( z S / t ) L_{\\text {pred }}=-\\operatorname{softmax}\\left(z^{T}\\right) \\cdot \\log _{-} \\operatorname{softmax}\\left(z^{S} / t\\right) Lpred =softmax(zT)logsoftmax(zS/t)

6、总体损失函数

L layer  ( S m , T g ( m ) ) = { L embd  ( S 0 , T 0 ) , m = 0 L hidn  ( S m , T g ( m ) ) + L attn  ( S m , T g ( m ) ) , M ≥ m > 0 L pred  ( S M + 1 , T N + 1 ) , m = M + 1 \\mathcal{L}_{\\text {layer }}\\left(S_{m}, T_{g(m)}\\right)=\\left\\{\\begin{array}{ll} \\mathcal{L}_{\\text {embd }}\\left(S_{0}, T_{0}\\right), & m=0 \\\\ \\mathcal{L}_{\\text {hidn }}\\left(S_{m}, T_{g(m)}\\right)+\\mathcal{L}_{\\text {attn }}\\left(S_{m}, T_{g(m)}\\right), & M \\geq m>0 \\\\ \\mathcal{L}_{\\text {pred }}\\left(S_{M+1}, T_{N+1}\\right), & m=M+1 \\end{array}\\right. Llayer (Sm,Tg(m))=Lembd (S0,T0),Lhidn (S以上是关于8.3 bert的蒸馏讲解 意境级的主要内容,如果未能解决你的问题,请参考以下文章

8.2 知识蒸馏 讲解 意境级

8.4 bert的压缩讲解 意境级

6.9意境级讲解BERT更好的进行微调方法总结

21 意境级讲解 共指消解的方法

10.1 意境级讲解关系抽取

二分查找算法算法指导 意境级讲解