8.2 知识蒸馏 讲解 意境级
Posted 炫云云
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了8.2 知识蒸馏 讲解 意境级相关的知识,希望对你有一定的参考价值。
文章目录
相关论文下载:模型压缩方法与bert压缩的论文
知识蒸馏是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方法,由于其简单,有效,在工业界被广泛应用。这一技术的理论来自于2015年Hinton发表的一篇神作:Distilling the Knowledge in a Neural Network 1
Knowledge Distillation,简称KD,顾名思义,就是将已经训练好的模型包含的知识(”Knowledge”),蒸馏(“Distilling”)提取到另一个模型里面去。
「温度」: 我们都知道“蒸馏”需要在高温下进行,那么这个“蒸馏”的温度代表了什么,又是如何选取合适的温度?
1、介绍
论文提出的背景:
虽然在一般情况下,我们不会去区分训练和部署使用的模型,但是训练和部署之间存在着一定的不一致性:
-
在训练过程中,我们需要使用复杂的模型,大量的计算资源,以便从非常大、高度冗余的数据集中提取出信息。在实验中,效果最好的模型往往规模很大,甚至由多个模型集成得到。而大模型不方便部署到服务中去,常见的瓶颈如下:
-
- 推断速度慢
- 对部署资源要求高(内存,显存等)
-
在部署时,我们对延迟以及计算资源都有着严格的限制。
因此,模型压缩(在保证性能的前提下减少模型的参数量)成为了一个重要的问题。而**「知识蒸馏」**属于模型压缩的一种方法。
插句题外话,我们可以从模型参数量和训练数据量之间的相对关系来理解underfitting和overfitting。AI领域的从业者可能对此已经习以为常,但是为了力求让小白也能读懂本文,还是形象地说明一下:
模型就像一个容器,训练数据中蕴含的知识就像是要装进容器里的水。
- 当数据知识量(水量)超过模型所能建模的范围时(容器的容积,模型的参数量),加再多的数据也不能提升效果(水再多也装不进容器),因为模型的表达空间有限(容器容积有限),就会造成underfitting;
- 而当模型的参数量大于已有知识所需要的表达空间时(容积大于水量,水装不满容器),就会造成overfitting,即模型的bias会增大(想象一下摇晃半满的容器,里面水的形状是不稳定的)。
“思想歧路”
上面容器和水的比喻非常经典和贴切,但是会引起一个误解: 人们在直觉上会觉得,要保留相近的知识量,必须保留相近规模的模型。也就是说,一个模型的参数量基本决定了其所能捕获到的数据内蕴含的“知识”的量。
这样的想法是基本正确的,但是需要注意的是:
- 模型的参数量和其所能捕获的“知识“量之间并非稳定的线性关系(下图中的1),而是接近边际收益逐渐减少的一种增长曲线(下图中的2和3)
- 完全相同的模型架构和模型参数量,使用完全相同的训练数据,能捕获的“知识”量并不一定完全相同,另一个关键因素是训练的方法。合适的训练方法可以使得在模型参数总量比较小时,尽可能地获取到更多的“知识”(下图中的3与2曲线的对比).
2、为什么要有知识蒸馏?
深度学习在计算机视觉、语音识别、自然语言处理等内的众多领域中均取得了令人难以置信的性能。但是,大多数模型在计算上过于昂贵,无法在移动端或嵌入式设备上运行。因此需要对模型进行压缩,且知识蒸馏是模型压缩中重要的技术之一。
1. 提升模型精度
如果对目前的网络模型A的精度不是很满意,那么可以先训练一个更高精度的teacher模型B(通常参数量更多,时延更大),然后用这个训练好的teacher模型B对student模型A进行知识蒸馏,得到一个更高精度的A模型。
2. 降低模型时延,压缩网络参数
如果对目前的网络模型A的时延不满意,可以先找到一个时延更低,参数量更小的模型B,通常来讲,这种模型精度也会比较低,然后通过训练一个更高精度的teacher模型C来对这个参数量小的模型B进行知识蒸馏,使得该模型B的精度接近最原始的模型A,从而达到降低时延的目的。
3. 标签之间的域迁移
假如使用狗和猫的数据集训练了一个teacher模型A,使用香蕉和苹果训练了一个teacher模型B,那么就可以用这两个模型同时蒸馏出一个可以识别狗、猫、香蕉以及苹果的模型,将两个不同域的数据集进行集成和迁移。
因此,在工业界中对知识蒸馏和迁移学习也有着非常强烈的需求
3、知识蒸馏的理论依据
3.1、Teacher Model和Student Model
知识蒸馏使用的是Teacher—Student模型,其中teacher是“知识”的输出者,student是“知识”的接受者。知识蒸馏的过程分为2个阶段:
- 原始模型训练: 训练”Teacher模型”, 简称为Net-T,它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对”Teacher模型”不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入X, 其都能输出Y,其中Y经过softmax的映射,输出值对应相应类别的概率值。
- 精简模型训练: 训练”Student模型”, 简称为Net-S,它是参数量较小、模型结构相对简单的单模型。同样的,对于输入X,其都能输出Y,Y经过softmax映射后同样能输出对应相应类别的概率值。
在本论文中,作者将问题限定在**「分类问题」**下,或者其他本质上属于分类问题的问题,该类问题的共同点是模型最后会有一个softmax层,其输出值对应了相应类别的概率值。
3.2、知识蒸馏的关键点
如果回归机器学习最最基础的理论,我们可以很清楚地意识到一点(而这一点往往在我们深入研究机器学习之后被忽略): 机器学习**「最根本的目的」**在于训练出在某个问题上泛化能力强的模型。
- 泛化能力强: 在某问题的所有数据上都能很好地反应输入和输出之间的关系,无论是训练数据,还是测试数据,还是任何属于该问题的未知数据。
而现实中,由于我们不可能收集到某问题的所有数据来作为训练数据,并且新数据总是在源源不断的产生,因此我们只能退而求其次,训练目标变成在已有的训练数据集上建模输入和输出之间的关系。由于训练数据集是对真实数据分布情况的采样,训练数据集上的最优解往往会多少偏离真正的最优解(这里的讨论不考虑模型容量)。
而在知识蒸馏时,由于我们已经有了一个泛化能力较强的Net-T,我们在利用Net-T来蒸馏训练Net-S时,可以直接让Net-S去学习Net-T的泛化能力。
一个很直白且高效的迁移泛化能力的方法就是使用softmax层输出的类别的概率来作为“soft target”,即Net-S学习“soft target”分布
- 传统training过程(hard targets): 对ground truth求极大似然
- KD的training过程(soft targets): 用large model的class probabilities作为soft targets
- Hard-target:原始数据集标注的 one-shot 标签,除了正标签为 1,其他负标签都是 0。
- Soft-target:Teacher模型softmax层输出的类别概率,每个类别都分配了概率,正标签的概率最高。
上 图 : Hard Target 下图: Soft Target 上图: \\text{Hard Target 下图: Soft Target} 上图:Hard Target 下图: Soft Target
「为什么soft target进行训练会有效果?」
softmax层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签。而在传统的训练过程(hard target)中,所有负标签都被统一对待。因为KD的训练方式使得每个样本给Net-S带来的信息量大于传统的训练方式,所以有效果。
举个例子来说明一下: 在手写体数字识别任务MNIST中,输出类别有10个。
假设某个输入的“2”更加形似”3”,softmax的输出值中”3”对应的概率为0.1,而其他负标签对应的值都很小,而另一个”2”更加形似”7”,”7”对应的概率为0.1。这两个”2”对应的hard target的值是相同的,但是它们的soft target却是不同的,由此我们可见soft target蕴含着比hard target多的信息。并且soft target分布的熵相对高时,其soft target蕴含的知识就更丰富。
两个”2“的 hard target 相同而 soft target不同 \\text{两个”2“的 hard target 相同而 soft target不同} 两个”2“的 hard target 相同而 soft target不同
这就解释了为什么通过蒸馏的方法训练出的Net-S相比使用完全相同的模型结构和训练数据只使用hard target的训练方法得到的模型,拥有更好的泛化能力。
4、知识蒸馏的具体方法
4.1、Logits蒸馏
Logits蒸馏是知识从老师模型的输出的Logit 学习得到;
logits是什么?
softmax层的输入,汇总了网络内部各种信息后,得出的属于各个类别的汇总分值 z i z_i zi,就是 Logits。 i i i 代表第 i i i个类别, z i z_{i} zi 代表属于第 i i i类的可能性。
对于一般的分类问题,比 如图片分类,输入一张图片后,经过DNN网络各种非线性变换,在网络最后 Softmax层之前,会得到这张图片属于各个类别的大小数值 z i z_{i} zi, 某个类别的 z i z_{i} zi 数值越大,则模型认为输入图片属于这个类别的可能性就越大。
softmax函数
因为Logits并非概率值,所以一般在Logits数值上会用Softmax函数进行变换,作为最终分类结果的概率值。 Softmax \\operatorname{Softmax} Softmax 一方面把Logits数值在各类别之间进行概率归一, 使得各个类别归属数值满足概率分布; 另外一方面,它会放大Logits数值之间的差异,使得 Logits得分两极分化,Logits得分高的得到的概率值更偏大一些,而较低的Logits数值,得到的概率值则更小。
先回顾一下原始的softmax函数:
q
i
=
exp
(
z
i
)
∑
j
exp
(
z
j
)
q_{i}=\\frac{\\exp \\left(z_{i}\\right)}{\\sum_{j} \\exp \\left(z_{j}\\right)}
qi=∑jexp(zj)exp(zi)
但要是直接使用softmax层的输出值作为soft target, 这又会带来一个问题: 当softmax输出的概率分布熵相对较小时,负标签的值都很接近0,对损失函数的贡献非常小,小到可以忽略不计。因此”温度”这个变量就派上了用场。
下面的公式时加了温度这个变量之后的softmax函数:
q
i
=
exp
(
z
i
/
T
)
∑
j
exp
(
z
j
/
T
)
q_{i}=\\frac{\\exp \\left(z_{i} / T\\right)}{\\sum_{j} \\exp \\left(z_{j} / T\\right)}
qi=∑jexp(zj/T)exp(zi/T)
- 这里的T就是**「温度」**。
- 原来的softmax函数是 T = 1 T = 1 T=1的特例。T越高,softmax的output 概率分布越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。反之 T T T的温度越低softmax函数就越陡峭。
关于”温度”的讨论
【问题】 我们都知道“蒸馏”需要在高温下进行,那么这个“蒸馏”的温度代表了什么,又是如何选取合适的温度?
随着温度T的增大,概率分布的熵逐渐增大
在回答这个问题之前,先讨论一下**「温度T的特点」**
- 原始的softmax函数是 T = 1 T =1 T=1时的特例, T < 1 T<1 T<1时,概率分布比原始更“陡峭”, T > 1 T>1 T>1 时,概率分布比原始更“平缓”。
- 温度越高,softmax上各个值的分布就越平均(思考极端情况: (i) T = ∞ T=\\infty T=∞, 此时softmax的值是平均分布的;(ii) T → 0 T \\rightarrow 0 T→0 ,此时softmax的值就相当于 argmax \\operatorname{argmax} argmax,即最大的概率处的值趋近于1,而其他值趋近于0)
- 不管温度T怎么取值,Soft target都有忽略小的 q i q_{i} qi携带的信息的倾向
温度代表了什么,如何选取合适温度?
温度的高低改变的是Net-S训练过程中对负标签的关注程度: 温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Net-S会相对多地关注到负标签。
实际上,负标签中包含一定的信息,尤其是那些值显著**「高于」**平均值的负标签。但由于Net-T的训练过程决定了负标签部分比较noisy,并且负标签的值越低,其信息就越不可靠。因此温度的选取比较经验主义的,本质上就是在下面两件事之中取舍:
- 从有部分信息量的负标签中学习 –> 温度要高一些
- 防止受负标签中噪声的影响 –>温度要低一些
总的来说,T的选择和Net-S的大小有关,Net-S参数量比较小的时候,相对比较低的温度就可以了(因为参数量小的模型不能学习到所有的知识,所以可以适当忽略掉一些负标签的信息)
通用的知识蒸馏方法
- 训练好Teacher模型;
- 利用高温 T h i g h T_{h i g h} Thigh 产生 Soft-target
- 使用 { \\left\\{\\right. { Soft − - − target, T high } \\left.T_{\\text {high }}\\right\\} Thigh } 和 { \\{ { Hard − - − target, T = 1 } T=1\\} T=1} 同时训练 Student模型
- 设置温度 T = 1 T=1 T=1, Student模型线上做inference。
模型蒸馏的过程如下图所示,,在知识蒸馏中Teacher模型和Student模型采用同样的温度,训练过程中Teacher将输出的soft label作为结果提供给Student模型进行学习,这就是图中的distillation loss也是(soft targets)。
同时为了避免老师也有出错的时候,学生也会对目标标签进行学习,这也就是图中的student loss(也是hard targets)。
$$ 知识蒸馏示意图 $$下面详细讲讲第二步:高温蒸馏的过程。如上图所示,高温蒸馏过程的目标函数由distill loss(对应soft target)和student loss(对应hard target)加权得到。
L
=
(
1
−
α
)
L
s
o
f
t
+
α
L
h
a
r
d
L=(1-\\alpha ) L_{s o f t}+\\alpha L_{h a r d}
L=(1−α)Lsoft+αLhard
- v i v_{i} vi: Net-T的logits
- z i : z_{i}: zi: Net-S的logits
- p i T p_{i}^{T} piT: Net-T的在温度=T下的softmax输出在第 i i i类上的值
- q i T q_{i}^{T} qiT: Net-S的在温度=T下的softmax输出在第 i i i类上的值
- c i : c_{i}: ci: 在第i类上的ground truth值, c i ∈ 0 , 1 , c_{i} \\in 0,1, ci∈0,1, 正标签取1,负标签取0.
- N N N : 总标签数量
- Net-T 和 Net-S同时输入 transfer set (这里可以直接复用训练Net-T用到的training set), 用Net-T 在高温 T T T 产生的softmax distribution 来作为soft target,Net-S在相同温度T下的softmax输出和soft target的cross entropy就是**「Loss函数的第一部分」** L s o f t L_{s o f t} Lsoft.
L
s
o
f
t
=
以上是关于8.2 知识蒸馏 讲解 意境级的主要内容,如果未能解决你的问题,请参考以下文章