Implementing Channel and Spatial Attention in PyTorch, Following the Paper's Design
Posted by 算法与编程之美
This post walks through a PyTorch implementation of the two attention mechanisms described in the paper, channel attention and spatial attention, and is offered for reference.
import torch
from torch import nn

class ChannelAttention(nn.Module):
    # ratio is the reduction factor applied to in_plances in the hidden layer of the MLP
    def __init__(self, in_plances, ratio=16) -> None:
        super().__init__()
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        '''
        (1) in_plances / ratio yields a float, which makes the model raise an error;
        (2) in_plances // ratio uses floor division, giving an integer channel count;
        (3) bias=False in Conv2d is used so the 1x1 convolutions mimic the fully-connected layers of an MLP.
        '''
        self.mlp = nn.Sequential(  # note: no square brackets here, the layers are passed directly, not as a list
            nn.Conv2d(in_plances, in_plances // ratio, 1, bias=False),  # no bias: the 1x1 conv plays the role of an FC layer
            nn.ReLU(),
            nn.Conv2d(in_plances // ratio, in_plances, 1, bias=False)   # mind the difference between / and // in Python
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1 = self.max_pool(x)
        x1 = self.mlp(x1)
        x2 = self.avg_pool(x)
        x2 = self.mlp(x2)
        # the two branches are added element-wise here, not concatenated
        # (i.e. not torch.cat([x1, x2], dim=1))
        out = x1 + x2
        out = self.sigmoid(out)
        return out
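
# Note (added for clarity, not in the original code): the module above computes the CBAM-style
# channel attention map Mc(F) = sigmoid( MLP(MaxPool(F)) + MLP(AvgPool(F)) ), of shape B x C x 1 x 1;
# it is meant to be multiplied back onto the input feature map by broadcasting.
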
class SpatialAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv2d = nn.Conv2d(2, 1, 7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        '''
        Note: these are not plain spatial max/avg pooling operations; the max and the mean
        are taken across the channel dimension (cross-channel pooling).
        '''
        avg_pool = torch.mean(x, dim=1, keepdim=True)    # Bx1xHxW
        max_pool, _ = torch.max(x, dim=1, keepdim=True)  # Bx1xHxW; torch.max returns (values, indices), so forgetting the ", _" is a common mistake
        # Bx2xHxW
        out = torch.cat([avg_pool, max_pool], dim=1)
        out = self.conv2d(out)
        out = self.sigmoid(out)
        return out
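
# Illustrative addition (not in the original post): the two modules are typically chained into a
# CBAM-style block -- channel attention first, then spatial attention, each applied to the features
# by broadcast multiplication. The class name and structure below are an assumed sketch.
class CBAM(nn.Module):
    def __init__(self, in_plances, ratio=16) -> None:
        super().__init__()
        self.channel_attention = ChannelAttention(in_plances, ratio)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        x = x * self.channel_attention(x)  # BxCxHxW * BxCx1x1 -> BxCxHxW
        x = x * self.spatial_attention(x)  # BxCxHxW * Bx1xHxW -> BxCxHxW
        return x
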
if __name__ == '__main__':
    from torchinfo import summary
    # import hiddenlayer as h

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    print('Channel Attention')
    layer = ChannelAttention(32).to(device)
    summary(layer, (1, 32, 224, 224))

    print('Spatial Attention')
    layer = SpatialAttention().to(device)
    summary(layer, (1, 32, 224, 224))

    # graph = h.build_graph(layer, torch.zeros([1, 32, 224, 224]))
    # graph.theme = h.graph.THEMES['blue'].copy()
    # graph.save('test.png')
    print('done!')
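
As a quick sanity check (not part of the original script; the variable names and tensor size below are arbitrary), the following shows the shapes of the two attention maps and how they broadcast back onto the input:

import torch

x = torch.randn(1, 32, 224, 224)          # arbitrary example feature map
ca = ChannelAttention(32)
sa = SpatialAttention()

print(ca(x).shape)        # torch.Size([1, 32, 1, 1])    -- one weight per channel
print(sa(x).shape)        # torch.Size([1, 1, 224, 224]) -- one weight per spatial position
print((x * ca(x)).shape)  # broadcasting restores torch.Size([1, 32, 224, 224])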