Swin Transformer代码实现部分细节重点
Posted weixin_44040169
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Swin Transformer代码实现部分细节重点相关的知识,希望对你有一定的参考价值。
swin transformer
1.patch-merging部分
代码:【amazing】
x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C] 对应图片所有 1 的位置
x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C] 对应图片所有 3 的位置
x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C] 对应图片所有 2 的位置
x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C] 对应图片所有 4 的位置
x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C] 拼在一起,通道变为4倍
x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C]
x = self.norm(x)
x = self.reduction(x) # [B, H/2*W/2, 2*C] self.reduction = nn.Linear(4*dim, 2*dim, bias=False)一个线性映射使通道变为2倍
2.create mask部分(有点懵)
![在这里插入图片描述](https://img-blog.csdnimg.cn/ebc36327a9b84806b96d6d50c9f12dcd.png
划分窗口
相同的数字是连续的区域
代码:
h_slices = (slice(0, -self.window_size), #切片 [0,-3) 正着数是从第一个开始记为0,倒着数从最后一个开始记为-1
slice(-self.window_size, -self.shift_size),# [-3,-1)
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices: # 给区域标号
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# 划分window窗口
mask_windows = window_partition(img_mask, self.window_size) # [nW, Mh, Mw, 1]窗口个数,窗口宽,高,通道数
mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # [nW, Mh*Mw]
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
# [nW, Mh*Mw, Mh*Mw] 利用广播机制
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
3.window attention
相对位置编码
整体流程(摘自博客)
增加维度,下图图示【以下这些维度操作,就很amazing!!!】
利用广播机制相减,得到相对位置编码(摘自B导视频)
如下图中颜色对应的坐标相减
这是permute变换前后变化,从横纵坐标分离 到 横纵坐标和在一起
代码:
# 相对位置编码
# get pair-wise relative position index for each token inside the window
#首先 生成绝对位置索引
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1]) # 生成网格坐标索引 堆叠
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # [2, Mh, Mw]
coords_flatten = torch.flatten(coords, 1) # [2, Mh*Mw] 并展开为2D向量
# coords_flatten[:, None, :] 在一维处插入新维度 , coords_flatten[:, :, None] 在二维处插入新维度
# [2, Mh*Mw, 1] - [2, 1, Mh*Mw] 利用广播机制 就是通过相减得到他们的相对位置关系
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Mh*Mw, Mh*Mw]
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # [Mh*Mw, Mh*Mw, 2] 调换位置
#把二元索引变成一元索引
relative_coords[:, :, 0] += self.window_size[0] - 1 # 坐标转换为从0开始
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 #行坐标乘(2M-1)
relative_position_index = relative_coords.sum(-1) # [Mh*Mw, Mh*Mw] 最后一个维度求和
self.register_buffer("relative_position_index", relative_position_index) #注册为不参与网络学习的变量,
# #作用是根据最终的相对位置索引 找到对应的可学习的相对位置编码
以上是关于Swin Transformer代码实现部分细节重点的主要内容,如果未能解决你的问题,请参考以下文章
pytorch 笔记: Swin-Transformer 代码
Swin Transformer v2实战:使用Swin Transformer v2实现图像分类
计算机视觉算法——Vision Transformer / Swin Transformer
Swin Transformer实战:使用 Swin Transformer实现图像分类。