Implementing the Channel and Spatial Attention Mechanisms in PyTorch, Following the Paper's Idea

Posted by 算法与编程之美



import torch
from torch import nn



class ChannelAttention(nn.Module):
    
    # ratio is the factor by which in_planes is reduced in the MLP's hidden layer
    def __init__(self, in_planes, ratio=16) -> None:
        super().__init__()
        
        self.max_pool = nn.AdaptiveMaxPool2d((1,1))
        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        
        
        '''
        (1) in_planes / ratio yields a float, which makes the model raise an error;
        (2) in_planes // ratio floors the result to an integer;
        (3) bias=False in the Conv2d layers is there to emulate the FC layers of an MLP;
        '''
        self.mlp = nn.Sequential( # note: modules are passed directly, not wrapped in a list
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), # 1x1 conv with no bias, acting as an FC layer
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) # note the difference between / and // in Python
        )
        
        self.sigmoid = nn.Sigmoid()
        
        
    def forward(self, x):
        x1 = self.max_pool(x)
        x1 = self.mlp(x1)
        
        x2 = self.avg_pool(x)
        x2 = self.mlp(x2)
        
        # the two branches are added element-wise here, not concatenated
        # (i.e. not torch.cat([x1, x2], dim=1))
        
        out = x1 + x2
        out = self.sigmoid(out)
        
        return out
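
# Usage sketch (an assumption, not part of the original post): the B x C x 1 x 1
# attention map returned by ChannelAttention is meant to rescale the input
# feature map by broadcast multiplication, e.g.
#   ca = ChannelAttention(32)
#   refined = x * ca(x) # x: B x 32 x H x W, each channel weighted by a value in [0, 1]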





class SpatialAttention(nn.Module):
    
    def __init__(self) -> None:
        super().__init__()
        
        
        # 7x7 conv over the 2-channel pooled map; padding=3 keeps H and W unchanged
        self.conv2d = nn.Conv2d(2, 1, 7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()        
        
        
    def forward(self, x):
        '''
        Note: these are not plain spatial max/avg pooling ops; the pooling is
        done across the channel dimension (cross-channel).
        '''
        avg_pool = torch.mean(x, dim=1, keepdim=True) # Bx1xHxW
        max_pool, _ = torch.max(x, dim=1, keepdim=True) # Bx1xHxW; torch.max returns (values, indices), so the _ is easy to miss
        
        # Bx2xHxW
        out = torch.cat([avg_pool, max_pool], dim=1)
        out = self.conv2d(out)
        out = self.sigmoid(out)
        
        return out
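

# Sketch (an assumption, not part of the original post): following the CBAM paper,
# the two modules are typically applied in sequence, channel attention first and
# spatial attention second, each rescaling the feature map by broadcast multiplication.
class CBAM(nn.Module):
    
    def __init__(self, in_planes, ratio=16) -> None:
        super().__init__()
        self.channel_attention = ChannelAttention(in_planes, ratio)
        self.spatial_attention = SpatialAttention()
        
    def forward(self, x):
        x = x * self.channel_attention(x) # BxCxHxW * BxCx1x1
        x = x * self.spatial_attention(x) # BxCxHxW * Bx1xHxW
        return x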
    
    

if __name__ == '__main__':
    from torchinfo import summary
    # import hiddenlayer as h
    
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    
    print('Channel Attention')
    layer = ChannelAttention(32).to(device)
    summary(layer, (1, 32, 224, 224))
    
    
    print('Spatial Attention')
    layer = SpatialAttention().to(device)
    summary(layer, (1, 32, 224, 224))
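    
    
    # Sketch continued (assumption, not part of the original script):
    # summarize the combined CBAM block defined above
    print('CBAM')
    layer = CBAM(32).to(device)
    summary(layer, (1, 32, 224, 224))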
    
    
    
    # graph = h.build_graph(layer, torch.zeros([1, 32, 224, 224]))
    # graph.theme = h.graph.THEMES['blue'].copy()
    # graph.save('test.png')
    
    print('done!')
