知识蒸馏:Distilling the Knowledge in a Neural Network

Posted AI浩

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了知识蒸馏:Distilling the Knowledge in a Neural Network相关的知识,希望对你有一定的参考价值。

文章目录

摘要

提高几乎所有机器学习算法性能的一个非常简单的方法是用相同的数据训练许多不同的模型,然后对它们的预测[3]求平均值。不幸的是,使用整个模型集合进行预测是很麻烦的,而且可能计算成本太高,无法部署到大量用户中,特别是当单个模型是大型神经网络时。Caruana和他的合作者[1]已经证明,可以将集合中的知识压缩到一个更容易部署的单一模型中,我们使用不同的压缩技术进一步开发了这种方法。我们在MNIST上取得了一些令人惊讶的结果,我们表明,通过将模型集合中的知识提取到单个模型中,我们可以显著改善一个大量使用的商业系统的声学模型。我们还介绍了一种新的集成类型,它由一个或多个完整模型和许多专业模型组成,这些专业模型学会区分完整模型混淆的细粒度类。不同于专家的混合,这些专家模型可以快速和并行地训练。

1 简介

许多昆虫都有一种幼虫形态,能从环境中提取能量和营养物质,而另一种完全不同的成虫形态则能满足旅行和繁殖的不同需求。在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的需求非常不同:对于语音和对象识别等任务,训练必须从非常大的、高度冗余的数据集中提取结构,但它不需要实时操作,它可以使用大量的计算。然而,面向大量用户的部署对延迟和计算资源的要求要严格得多。与昆虫的类比表明,如果能更容易地从数据中提取结构,我们应该愿意训练非常繁琐的模型。繁琐的模型可以是单独训练的模型的集合,也可以是使用非常强的正则化器(如dropout[9])训练的单个非常大的模型。一旦训练了繁琐的模型,我们就可以使用另一种训练,我们称之为“蒸馏”,将知识从繁琐的模型转移到更适合部署的小模型中。这种策略的一个版本已经由Rich Caruana和他的合作者[1]率先提出。在他们的重要论文中,他们令人信服地论证了从大量模型集合中获得的知识可以转移到单个小模型中。

一个概念上的障碍可能阻止了对这一非常有前途的方法进行更多的研究,那就是我们倾向于用学习到的参数值来识别训练过的模型中的知识,这使得我们很难看到如何改变模型的形式而保持相同的知识。知识的一个更抽象的观点是,它从任何特定的实例化中解放出来,它是一个从输入向量到输出向量的学习映射。对于学习区分大量类的繁琐模型,通常的训练目标是使正确答案的平均对数概率最大化,但学习的一个副作用是,训练过的模型为所有的错误答案分配概率,即使这些概率非常小,其中一些也比其他的大得多。错误答案的相对概率告诉我们这个繁琐的模型是如何趋于一般化的。例如,一辆宝马的图像可能只有很小的几率被误认为是一辆垃圾车,但这种错误的可能性仍然比把它误认为是一根胡萝卜的可能性高很多倍。

人们普遍认为,用于训练的目标函数应尽可能准确地反映使用者的真正目标。尽管如此,当真正的目标是很好地推广到新数据时,通常训练模型来优化训练数据的性能。训练模型进行良好的泛化显然更好,但这需要关于泛化的正确方法的信息,而这些信息通常是不可用的。然而,当我们从一个大模型中提取知识到一个小模型中时,我们可以训练小模型以与大模型相同的方式进行归纳。如果繁琐的模型可以很好地泛化,例如,因为它是不同模型的大型集合的平均值,那么用相同方式训练的小模型在测试数据上的表现通常会比在训练集合的相同训练集上以正常方式训练的小模型好得多。人们普遍认为,用于训练的目标函数应尽可能准确地反映使用者的真正目标。尽管如此,当真正的目标是很好地推广到新数据时,通常训练模型来优化训练数据的性能。训练模型进行良好的泛化显然更好,但这需要关于泛化的正确方法的信息,而这些信息通常是不可用的。然而,当我们从一个大模型中提取知识到一个小模型中时,我们可以训练小模型以与大模型相同的方式进行归纳。如果繁琐的模型可以很好地泛化,例如,因为它是不同模型的大型集合的平均值,那么用相同方式训练的小模型在测试数据上的表现通常会比在训练集合的相同训练集上以正常方式训练的小模型好得多。

将繁琐模型的泛化能力转移到小模型上的一种明显的方法是将繁琐模型产生的类概率作为训练小模型的“软目标”。对于这个转移阶段,我们可以使用相同的训练集或单独的“转移”集。当复杂模型是较简单模型的大集合时,我们可以使用它们各自预测分布的算术或几何平均值作为软目标。当软目标具有高熵时,它们在每个训练案例中提供的信息要比硬目标多得多,在训练案例之间的梯度方差也要小得多,因此小模型通常可以用比原始繁琐模型少得多的数据进行训练,并使用更高的学习率。

对于像MNIST这样的任务,繁琐的模型几乎总是产生非常高置信度的正确答案,关于学习函数的大部分信息存在于软目标中非常小的概率的比率中。例如,一个版本的2可能被给出 1 0 − 6 10^ - 6 106的概率是3, 1 0 − 9 10 ^- 9 109的概率是7,而另一个版本可能是相反的情况。这是有价值的信息,它定义了数据上丰富的相似结构(例如,它说哪些2看起来像3,哪些看起来像7),但在传递阶段,它对交叉熵代价函数的影响非常小,因为概率非常接近于零。Caruana和他的合作者通过使用logit(最终softmax的输入)而不是softmax产生的概率作为学习小模型的目标来规避这个问题,他们最小化了繁琐模型产生的logit和小模型产生的logit之间的平方差。我们更通用的解决方案称为“蒸馏”,即提高最终softmax的温度,直到繁琐的模型产生合适的软目标集。然后我们在训练小模型时使用相同的高温来匹配这些软目标。我们稍后将说明,匹配这个繁琐模型的对数实际上是蒸馏的一种特殊情况。

用于训练小模型的传输集可以完全由未标记的数据[1]组成,也可以使用原始的训练集。我们发现,使用原始训练集效果很好,特别是如果我们在目标函数中添加一个小项,鼓励小模型预测真正的目标,并匹配繁琐模型提供的软目标。通常情况下,小模型不能精确匹配软目标,在正确答案的方向上出错是有帮助的。

2 蒸馏

神经网络通常通过使用“softmax”输出层来产生类概率 z i z_i zi,该输出层通过将 z i z_i zi与其他logit进行比较,计算出每个类的logit 转换为概率 q i q_i qi
q i = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) (1) q_i=\\frac\\exp \\left(z_i / T\\right)\\sum_j \\exp \\left(z_j / T\\right) \\tag1 qi=jexp(zj/T)exp(zi/T)(1)
其中T是通常设置为1的温度。使用较高的T值会在类上产生较软的概率分布。

在最简单的蒸馏形式中,通过在传输集上对蒸馏模型进行训练,并使用传输集中的每种情况下的软目标分布,将知识转移到蒸馏模型中,该传输集中的软目标分配是通过使用具有最高温度的繁琐模型产生的。在训练蒸馏模型时使用相同的高温,但在训练之后,它使用的温度为1。

当已知所有或部分传输集的正确标签时,还可以通过训练提取模型来产生正确的标签,从而显著改进该方法。一种方法是使用正确的标签来修改软目标,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。第一个目标函数是与软目标的交叉熵,并且该交叉熵是使用蒸馏模型的softmax中与用于从繁琐模型生成软目标相同的高温来计算的。第二个目标函数是带有正确标签的交叉熵。这是使用蒸馏模型的softmax中完全相同的logit进行计算的,但温度为1。我们发现,通常通过在第二个目标函数上使用适当较低的权重来获得最佳结果。由于软目标产生的梯度大小为 1 / T 2 1/T^2 1/T2,因此在使用硬目标和软目标时,将其乘以 T 2 T^2 T2非常重要。这确保了如果在实验元参数时改变用于蒸馏的温度,则硬靶和软靶的相对贡献大致保持不变。

2.1匹配逻辑是蒸馏的特殊情况

传输集中的每种情况都贡献了一个交叉熵梯度, d C / d z i d C / d z_i dC/dzi,相对于蒸馏模型的每个logit, z i z_i zi。如果繁琐的模型具有产生软目标概率 p i p_i pi的逻辑 v i v_i vi,并且转移训练在温度T下进行,则该梯度由下式给出:
∂ C ∂ z i = 1 T ( q i − p i ) = 1 T ( e z i / T ∑ j e z j / T − e v i / T ∑ j e v j / T ) \\frac\\partial C\\partial z_i=\\frac1T\\left(q_i-p_i\\right)=\\frac1T\\left(\\frace^z_i / T\\sum_j e^z_j / T-\\frace^v_i / T\\sum_j e^v_j / T\\right) ziC=T1(qipi)=T1(jezj/Tezi/Tjevj/Tevi/T)
如果温度与logits的量级相比较高,我们可以近似:
∂ C ∂ z i ≈ 1 T ( 1 + z i / T N + ∑ j z j / T − 1 + v i / T N + ∑ j v j / T ) \\frac\\partial C\\partial z_i \\approx \\frac1T\\left(\\frac1+z_i / TN+\\sum_j z_j / T-\\frac1+v_i / TN+\\sum_j v_j / T\\right) ziCT1(N+jzj/T1+zi/TN+jvj/T1+vi/T)
如果我们现在假设对于每个分动箱,logit分别为零均值,使得 ∑ j z j = ∑ j v j = 0 \\sum_j z_j=\\sum_j v_j=0 jzj=jvj=0。等式3简化为:
∂ C ∂ z i ≈ 1 N T 2 ( z i − v i ) \\frac\\partial C\\partial z_i \\approx \\frac1N T^2\\left(z_i-v_i\\right) ziCNT以上是关于知识蒸馏:Distilling the Knowledge in a Neural Network的主要内容,如果未能解决你的问题,请参考以下文章

深度学习方法(十五):知识蒸馏(Distilling the Knowledge in a Neural Network),在线蒸馏

深度学习方法(十五):知识蒸馏(Distilling the Knowledge in a Neural Network),在线蒸馏

知识蒸馏-Distilling the knowledge in a neural network

知识蒸馏:Distilling the Knowledge in a Neural Network

论文阅读_知识蒸馏_Distilling_BERT

文献阅读——The Augmented Image Prior Distilling 1000 Classes by Extrapolating from a Single Image