YOLOv5v7改进之三十二:引入SKAttention注意力机制

Posted 人工智能算法研究院

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了YOLOv5v7改进之三十二:引入SKAttention注意力机制相关的知识,希望对你有一定的参考价值。

 前 言:作为当前先进的深度学习目标检测算法YOLOv7,已经集合了大量的trick,但是还是有提高和改进的空间,针对具体应用场景下的检测难点,可以不同的改进方法。此后的系列文章,将重点对YOLOv7的如何改进进行详细的介绍,目的是为了给那些搞科研的同学需要创新点或者搞工程项目的朋友需要达到更好的效果提供自己的微薄帮助和参考。由于出到YOLOv7,YOLOv5算法2020年至今已经涌现出大量改进论文,这个不论对于搞科研的同学或者已经工作的朋友来说,研究的价值和新颖度都不太够了,为与时俱进,以后改进算法以YOLOv7为基础,此前YOLOv5改进方法在YOLOv7同样适用,所以继续YOLOv5系列改进的序号。另外改进方法在YOLOv5等其他算法同样可以适用进行改进。希望能够对大家有帮助。

具体改进办法请关注后私信留言!

解决问题:之前改进增加了很多注意力机制的方法,包括比较常规的SE、CBAM等,本文加入SKAttention注意力机制,该注意力机制借鉴了SENet的思想,通过动态计算每个卷积核得到通道的权重,动态的将各个卷积核的结果进行融合。用于可以让网络更加关注待检测目标,提高检测效果。

基本原理:

      在标准卷积神经网络(CNN)中,每层人工神经元的感受野被设计为共享相同的大小。在神经科学界众所周知,视觉皮层神经元的感受野大小受到刺激的调节,而在构建CNN时很少考虑刺激。我们提出了一种CNN中的动态选择机制,允许每个神经元根据输入信息的多尺度自适应调整其感受野大小。设计了一个称为选择性内核(SK)单元的构建块,其中不同内核大小的多个分支在这些分支中的信息指导下,使用softmax注意力进行融合。对这些分支的不同关注导致融合层神经元的有效感受野大小不同。多个SK单元堆叠在一个称为选择性内核网络(SKNets)的深层网络中。在ImageNet和CIFAR基准测试中,我们的经验表明,SKNet以更低的模型复杂度优于现有的最先进架构。详细分析表明,SKNet中的神经元能够捕获不同尺度的目标物体,这验证了神经元根据输入自适应调整其感受野大小的能力。

 添加方法:

第一步:确定添加的位置,作为即插即用的注意力模块,可以添加到YOLOv5网络中的任何地方。

 第二步:common.py构建SKAttention模块。部分代码如下,关注文章末尾,私信后领取。

class SKAttention(nn.Module):

    def __init__(self, channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
        super().__init__()
        self.d = max(L, channel // reduction)
        self.convs = nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)),
                    ('bn', nn.BatchNorm2d(channel)),
                    ('relu', nn.ReLU())
                ]))
            )
        self.fc = nn.Linear(channel, self.d)
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d, channel))
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs = []
        ### split
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats = torch.stack(conv_outs, 0)  # k,bs,channel,h,w

        ### fuse
        U = sum(conv_outs)  # bs,c,h,w

        ### reduction channel
        S = U.mean(-1).mean(-1)  # bs,c
        Z = self.fc(S)  # bs,d

        ### calculate attention weight
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights.append(weight.view(bs, c, 1, 1))  # bs,channel
        attention_weughts = torch.stack(weights, 0)  # k,bs,channel,1,1
        attention_weughts = self.softmax(attention_weughts)  # k,bs,channel,1,1

第三步:yolo.py中注册 SKAttention模块

        elif m is SKAttention:
            c1, c2 = ch[f], args[0]
            if c2 != no:
                c2 = make_divisible(c2 * gw, 8)
            args = [c1, *args[1:]]

第四步:修改yaml文件,本文以修改head(特征融合网络)为例,将原C3模块后加入该模块。

head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [-1, 1, SKAttention, [1024]],

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

第五步:将train.py中改为本文的yaml文件即可,开始训练。

结 果:本人在遥感数据集上进行实验,有涨点效果。需要请关注留言。

预告一下:下一篇内容将继续分享深度学习算法相关改进方法。有兴趣的朋友可以关注一下我,有问题可以留言或者私聊我哦

PS:该方法不仅仅是适用改进YOLOv5,也可以改进其他的YOLO网络以及目标检测网络,比如YOLOv7、v6、v4、v3,Faster rcnn ,ssd等。

最后,希望能互粉一下,做个朋友,一起学习交流。

以上是关于YOLOv5v7改进之三十二:引入SKAttention注意力机制的主要内容,如果未能解决你的问题,请参考以下文章

PX4模块设计之三十二:AttitudeEstimatorQ模块

Java经典编程题50道之三十二

Qt系列文章之三十二 (基于QThread的QReadWriteLock和QWaitCondition 的线程同步)

Qt系列文章之三十二 (基于QThread的QReadWriteLock和QWaitCondition 的线程同步)

易宝典文章——玩转O365中的EXO服务 之三十二 如何启用和禁用存档

风险模型在时间序列上的改进 ——《系列之三十一》