目标检测YOLOv5遇上知识蒸馏
Posted zstar-_
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了目标检测YOLOv5遇上知识蒸馏相关的知识,希望对你有一定的参考价值。
前言
模型压缩方法主要4种:
- 网络剪枝(Network pruning)
- 稀疏表示(Sparse representation)
- 模型量化(Model quantification)
- 知识蒸馏(Konwledge distillation)
本文主要来研究知识蒸馏的相关知识,并尝试用知识蒸馏的方法对YOLOv5进行改进。
知识蒸馏理论简介
概述
知识蒸馏(Knowledge Distillation)由深度学习三巨头Hinton在2015年提出。
论文标题:Distilling the knowledge in a neural network
论文地址:https://arxiv.org/pdf/1503.02531.pdf
“蒸馏”是个化工学科中的术语,本身指的是将液体混合物加热沸腾,使其中沸点较低的组分首先变成蒸气,再冷凝成液体,用来分离混合物。而知识蒸馏的含义和蒸馏本身相似但并不完全相同,知识蒸馏指的是同时训练两个网络,一个较复杂的网络作为教师网络,另一个较简单的网络作为学生网络,将教师网络训练得到的结果提炼出来,用来引导学生网络的结果,从而让学生网络学习得更好。
一个公认前提是小模型相比于大模型更容易陷入局部最优,下图[1]中,中间绿色的椭圆表示小网络模型的收敛空间,红色的椭圆表示大网络模型的收敛空间;如果不用知识蒸馏,直接训练小网络,它只会在绿色椭圆区域收敛,而使用知识蒸馏之后,小网络可以收敛到橙色椭圆区域,收敛到更小的最优点。
软标签
有了上面的概念,自然而然想到的一个问题就是,教师模型如何引导学生模型进行学习。这就涉及到论文中提及的一个概念——软标签(Soft target)
如上图[1]所示,以手写数字识别为例,这是一个10分类任务,左边这幅图是采用硬标签(Hard target),输出独热向量,概率最高的类别为1,其它类别为0;右边这幅图采用的是软标签(Soft target),通过softmax层输出的各类别概率,这样的输出具有更高的信息熵,即包含更多信息量。
教师模型输出软标签,从而指导学生模型学习。
softmax的原始公式是这样:
q i = exp ( z i ) ∑ j exp ( z j ) q_i=\\frac\\exp \\left(z_i\\right)\\sum_j \\exp \\left(z_j\\right) qi=∑jexp(zj)exp(zi)
在论文中,作者对这个公式又加以改进,引入了一个新的温度变量T,公式如下:
q i = exp ( z i / T ) ∑ j exp ( z j / T ) q_i=\\frac\\exp \\left(z_i / T\\right)\\sum_j \\exp \\left(z_j / T\\right) qi=∑jexp(zj/T)exp(zi/T)
加入这个变量,能使各类别之间的输出更均衡,如下图[2]所示,T=1为softmax,但是当T过大时,会发现输出向量会趋于一条直线,因此,T通常取中间较小值。
蒸馏温度
上面引入了一个新的变量温度T,这个T也可以称为蒸馏温度,原论文中给出了关于T的进一步讨论,随着T的增加,信息熵会越来越大,如下图[1]所示:
实际上,温度的高低改变的是Student模型训练过程中对负标签的关注程度。当温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Student模型会相对更多地关注到负标签[1]。
因此,T的取值可以遵循如下策略:
- 当想从负标签中学到一些信息量的时候,温度T应调高一些
- 当想减少负标签的干扰的时候,温度T应调低一些
需要注意的是,这个T只作用于教师网络和学生网络的蒸馏过程,学生网络正常输出仍使用softmax,即T取值为1,就像蒸馏过程一样,需要先进行升温,将知识蒸馏出来,然后输出的时候要冷却降温(T=1)
知识蒸馏过程
从原理上来讲,知识蒸馏没有想象中那么复杂,其流程如下图[1]所示:
- 在T下,训练教师网络得到
soft targets1
- 在T下,训练学生网络得到
soft targets2
- 通过
soft targets1
和soft targets2
得到distillation loss
- 在温度1下,训练学生网络得到
soft targets3
- 通过
soft targets3
和ground truth
得到student loss
通过这五个步骤,就得到了两个损失值 distillation loss
和 student loss
,那么训练的整体损失,就是这两个损失值的加权和,公式[2]如下:
注:
- 这里的蒸馏损失系数乘了一个
T
2
T^2
T2
这是由于soft targets产生的梯度大小按照 1 / T 2 1/T^2 1/T2进行了缩放,这里需要补充回来 -
α
\\alpha
α应远小于
β
\\beta
β
即需要让知识蒸馏损失权重大一些,否则没有蒸馏效果
后面,论文作者分别做了手写数字识别和声音识别实验,这里主要来看作者在MNIST数据集上的实验结果,结果如下表所示:
10xEnsemble是10个教师模型的平均值,Distilled Single model是Baseline模型经过蒸馏之后的结果,可以看到蒸馏出来的准确率提升了1.9%.
YOLOv5加上知识蒸馏
下面就将知识蒸馏融入到YOLOv5目标检测任务中,使用的是YOLOv5-6.0版本。
相关代码参考自:https://github.com/Adlik/yolov5
代码修改
其实知识蒸馏的想法很简单,在仓库作者的代码版本中,修改的内容也并不多,主要是模型加载和损失计算部分。
下面按照顺序来解读一下修改内容。
首先是train_distillation.py
这个文件,通过修改train.py
得到。
新增四个参数:
parser.add_argument('--t_weights', type=str, default='./weights/yolov5s.pt',
help='initial teacher model weights path')
parser.add_argument('--t_cfg', type=str, default='models/yolov5s.yaml', help='teacher model.yaml path')
parser.add_argument('--d_output', action='store_true', default=False,
help='if true, only distill outputs')
parser.add_argument('--d_feature', action='store_true', default=False,
help='if true, distill both feature and output layers')
-
t_weights
教师模型权重,和学生模型加载类似 -
t_cfg
教师模型配置,和学生模型配置类似 -
d_output
这个参数写在这里但不起作用,应该是作者调试时用到的参数,默认是只蒸馏结果 -
d_feature
这个参数默认是关闭,如果开启,蒸馏损失计算将不仅仅是计算两个模型输出的结果,并且中间特征层也会参与计算(不过这个作者没写完整,可能写到一半弃坑了)
模型加载:
这部分需要多加载一个教师模型,相关代码如下:
# Model
check_suffix(weights, '.pt') # check weights
pretrained = weights.endswith('.pt')
if pretrained:
with torch_distributed_zero_first(LOCAL_RANK):
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
model.load_state_dict(csd, strict=False) # load
LOGGER.info(f'Transferred len(csd)/len(model.state_dict()) items from weights') # report
# 这里添加加载教师模型
# Teacher model
LOGGER.info(f'Loaded teacher model t_cfg') # report
t_ckpt = torch.load(t_weights, map_location=device) # load checkpoint
t_model = Model(t_cfg or t_ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)
exclude = ['anchor'] if (t_cfg or hyp.get('anchors')) and not resume else [] # exclude keys
csd = t_ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = intersect_dicts(csd, t_model.state_dict(), exclude=exclude) # intersect
t_model.load_state_dict(csd, strict=False) # load
损失计算:
这里多了一个d_outputs_loss
,也就是计算蒸馏损失
s_loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
d_outputs_loss = compute_distillation_output_loss(pred, t_pred, model, d_weight=10)
loss = d_outputs_loss + s_loss
蒸馏损失在loss.py
中进行定义:
def compute_distillation_output_loss(p, t_p, model, d_weight=1):
t_ft = torch.cuda.FloatTensor if t_p[0].is_cuda else torch.Tensor
t_lcls, t_lbox, t_lobj = t_ft([0]), t_ft([0]), t_ft([0])
h = model.hyp # hyperparameters
red = 'mean' # Loss reduction (sum or mean)
if red != "mean":
raise NotImplementedError("reduction must be mean in distillation mode!")
DboxLoss = nn.MSELoss(reduction="none")
DclsLoss = nn.MSELoss(reduction="none")
DobjLoss = nn.MSELoss(reduction="none")
# per output
for i, pi in enumerate(p): # layer index, layer predictions
t_pi = t_p[i]
t_obj_scale = t_pi[..., 4].sigmoid()
# BBox
b_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, 4)
t_lbox += torch.mean(DboxLoss(pi[..., :4], t_pi[..., :4]) * b_obj_scale)
# Class
if model.nc > 1: # cls loss (only if multiple classes)
c_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, model.nc)
# t_lcls += torch.mean(c_obj_scale * (pi[..., 5:] - t_pi[..., 5:]) ** 2)
t_lcls += torch.mean(DclsLoss(pi[..., 5:], t_pi[..., 5:]) * c_obj_scale)
# t_lobj += torch.mean(t_obj_scale * (pi[..., 4] - t_pi[..., 4]) ** 2)
t_lobj += torch.mean(DobjLoss(pi[..., 4], t_pi[..., 4]) * t_obj_scale)
t_lbox *= h['box']
t_lobj *= h['obj']
t_lcls *= h['cls']
# bs = p[0].shape[0] # batch size
loss = (t_lobj + t_lbox + t_lcls) * d_weight
return loss
因为目标检测和原论文中的分类问题有所区别,并不能直接简单套用原论文提出的soft-target,那么这里的处理方式就是将三个损失(位置损失、目标损失、类别损失)简单粗暴地用MSELoss
进行计算,然后蒸馏损失就是这三部分之和。
值得注意的是,理论部分我们提到过,蒸馏损失需要比学生损失的权重更大,因此,这里在计算蒸馏损失中,加入了一个权重d_weight
,权重计算时取10.
下面是代码作者给出的一个实验结果:
Model | Compression strategy | Input size [h, w] | mAPval 0.5:0.95 | Pretrain weight |
---|---|---|---|---|
yolov5s | baseline | [640, 640] | 37.2 | pth | onnx |
yolov5s | distillation | [640, 640] | 39.3 | pth | onnx |
yolov5s | quantization | [640, 640] | 36.5 | xml | bin |
yolov5s | distillation + quantization | [640, 640] | 38.6 | xml | bin |
他采用的是coco数据集,用yolov5m作为教师模型,yolov5s作为学生模型,表格第二行展示了蒸馏之后的效果,mAP提升了2.1.
实验验证
为了验证蒸馏是否有效,我在VisDrone数据集上进行了实验,训练了100epoch,实验结果如下表所示:
Student Model | Teacher Model | Input size [h, w] | mAPtest 0.5 | mAPtest 0.5:0.95 |
---|---|---|---|---|
yolov5m | - | [640, 640] | 0.32 | 0.181 |
yolov5m | yolov5m | [640, 640] | 0.305 | 0.163 |
yolov5m | yolov5x | [640, 640] | 0.302 | 0.161 |
结果挺意外的,使用蒸馏训练之后,mAP反而下降了,严重怀疑蒸馏出来的是糟粕😵
结论
知识蒸馏理论上并不复杂,但经过实验,基本判断这玩意理论价值大于应用价值,用来讲故事可以,实际上提升效果非常有限。当然这是我做了有限实验得出的初步结论,如果读者有更好的思路,可以在评论区留言和我讨论。
TODU
总体而言,这次实验并不算成功,后面会换用其它模型组合以及调整训练参数,再做更多实验。
参考
[1]【论文泛读】 知识蒸馏:Distilling the knowledge in a neural network:https://www.bilibili.com/read/cv16841475
[2]【论文精讲|无废话版】知识蒸馏:https://www.bilibili.com/video/BV1h8411t7SA
以上是关于目标检测YOLOv5遇上知识蒸馏的主要内容,如果未能解决你的问题,请参考以下文章
AAAI2022基于秩模仿和预测引导特征模仿的目标检测知识蒸馏
弹性响应蒸馏 | 用弹性响应蒸馏克服增量目标检测中的灾难性遗忘