BYOL算法笔记
Posted AI之路
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了BYOL算法笔记相关的知识,希望对你有一定的参考价值。
论文:Bootstrap your own latent A new approach to self-supervised Learning
链接:https://arxiv.org/abs/2006.07733
代码:https://github.com/deepmind/deepmind-research/tree/master/byol
BYOL发表于NIPS2020,个人觉得是非常棒的一个工作,相信对未来的自监督领域发展会有较大的借鉴意义。
在讲这篇论文之前,先从自监督训练的崩塌问题开始说起。我们知道现在大部分的自监督训练都是通过约束同一张图的不同形态之间的特征差异性来实现特征提取,不同形态一般通过指定的数据增强实现,那么如果只是这么做的话(只有正样本对),网络很容易对所有输入都输出一个固定值,这样特征差异性就是0,完美符合优化目标,但这不是我们想要的,这就是训练崩塌了。因此一个自然的想法是我们不仅仅要拉近相同数据的特征距离,也要拉远不同数据的特征距离,换句话说就是不仅要有正样本对,也要有负样本对,这确实解决了训练崩塌的问题,但是也带来了一个新的问题,那就是对负样本对的数量要求较大,因为只有这样才能训练出足够强的特征提取能力,因此我们可以看到这方面的代表作如SimCLR系列都需要较大的batch size才能有较好的效果。
这篇论文提出的BYOL特点在于没有负样本对,这是一个非常新奇的想法,通过增加prediction和stop-gradient避免训练退化。整体上分为online network和target network两部分,如图Figure2所示,通过约束这2个网络输出特征的均方误差(MSE)来训练online network,而target network的参数更新取决于当前更新后的online network和当前的target network参数,这也就是论文中提到的slow-moving average做法,灵感来源于强化学习。
这篇论文的motivation来源于一个有趣的实验,首先有一个网络参数随机初始化且固定的target network,target network的top1准确率只有1.4%,target network输出feature作为另一个叫online network的训练目标,等这个online network训练好之后,online network的top1准确率可以达到18.8%,这就非常有意思了,假如将target network替换为效果更好的网络参数(比如此时的online network),然后再迭代一次,也就是再训练一轮online network,去学习新的target network输出的feature,那效果应该是不断上升的,类似左右脚踩楼梯不断上升一样。BYOL基本上就是这样做的,并且取得了非常好的效果。
算法伪代码如Algorithm 1所示,整体上还比较容易看懂。
实验结果方面,linear evaluation(特征提取网络的参数不变,仅训练新增的一个线性层参数)的结果还是很不错的,如Table1所示,ResNet-50能达到74.3%的top1 Acc,这个结果甚至要优于相同网络结构的SimCLR v2的结果(71.7%)。
半监督验证结果如Table2所示,在10%标签数据集上的结果基本上和后期的SimCLR v2比较接近,SimCLR v2中比较魔性的实验是添加SK结构后对效果的提升比较明显(基本上至少3~4个点),所以SimCLR v2那篇在网络结构的对比上还是做了很多工作。
BYOL相比SimCLR系列的一个有趣的点在于前者对batch size和数据增强更加鲁棒,论文中也针对这2个方面做了对比实验,如Figure3所示。大batch size对于训练机器要求较高,在SimCLR系列算法中主要起到提供足够的负样本对的作用,而BYOL中没有用到负样本对,因此更加鲁棒。数据增强也是同理,对对比学习的影响比较大,因此这方面BYOL还是很有优势的。
为什么BYOL不会有训练崩塌问题,论文中其实也提到了主要是2方面的原因:1、online network和target network并不是由一个损失函数来共同优化,也就是target network采用了slow-moving average的方式进行参数更新,参考Algorithm中的第11行。2、online network和target network的网络结构并不是完全一样的,online network还多了一个predictor结构。
回到一开始提到的motivation,里面比较重要的一步是target network的参数更新,采用的是Algorithm 1第11行的更新方式,那么这就涉及到计算权重的选择问题,这里作者做了对比实验如Table5(a)所示,T选择1的时候表示target network的参数一直都不变,就是前面motivation提到的结果18.8;T选择0的时候表示target network的参数完全由online network的参数替换,相当于每个step都要更新一下网络参数,可以看到这个时候效果非常差(其实就是训练崩塌),此时整个网络的训练波动比较大。因此中间3种T值的选择就是既不让更新速度过快,也不让更新速度过慢而设计的,整体上在每次更新时原target network的参数占比还是更大的。
在table5(b)中做了关于predictor、target network和是否有负样本对的充分对比实验,在贝塔=0时表示没有负样本对,可以看到此时的SimCLR不管是增加predictor还是target network,效果都非常差,注意看(b)中第一行和倒数第二行的对比,差别只在于有没有predictor,此时效果差异是巨大的,这也就是论文中提到的predictor结构是避免训练崩塌的重要元素之一,个人认为predictor的存在对于BYOL这种更新target network参数的方式而言提供了更多的学习空间,虽然结构很简单,但是不能没有。在贝塔=1时,可以看到是否有predictor对于BYOL和SimCLR都影响不大,可以理解为此时有负样本对保证了训练过程不会崩塌。
备注:T的更新公式,其中k表示step,K是max step。
以上是关于BYOL算法笔记的主要内容,如果未能解决你的问题,请参考以下文章