Swin Transformer代码实现部分细节重点

Posted weixin_44040169

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Swin Transformer代码实现部分细节重点相关的知识,希望对你有一定的参考价值。

swin transformer

1.patch-merging部分

代码:【amazing】

		x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]  对应图片所有 1 的位置
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]  对应图片所有 3 的位置
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]  对应图片所有 2 的位置
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]  对应图片所有 4 的位置
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C] 拼在一起,通道变为4倍

		x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]  self.reduction = nn.Linear(4*dim, 2*dim, bias=False)一个线性映射使通道变为2倍

2.create mask部分(有点懵)
![在这里插入图片描述](https://img-blog.csdnimg.cn/ebc36327a9b84806b96d6d50c9f12dcd.png
划分窗口

相同的数字是连续的区域
代码:

		h_slices = (slice(0, -self.window_size), #切片 [0,-3) 正着数是从第一个开始记为0,倒着数从最后一个开始记为-1
                    slice(-self.window_size, -self.shift_size),# [-3,-1)
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices: # 给区域标号
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
    # 划分window窗口
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]窗口个数,窗口宽,高,通道数
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # [nW, Mh*Mw, Mh*Mw] 利用广播机制
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

3.window attention
相对位置编码
整体流程(摘自博客)

增加维度,下图图示【以下这些维度操作,就很amazing!!!】

利用广播机制相减,得到相对位置编码(摘自B导视频)
如下图中颜色对应的坐标相减

这是permute变换前后变化,从横纵坐标分离 到 横纵坐标和在一起

代码:

 # 相对位置编码
        # get pair-wise relative position index for each token inside the window
        #首先 生成绝对位置索引
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])   # 生成网格坐标索引    堆叠
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw] 并展开为2D向量
        # coords_flatten[:, None, :] 在一维处插入新维度  , coords_flatten[:, :, None] 在二维处插入新维度
                                    # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]  利用广播机制 就是通过相减得到他们的相对位置关系
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2] 调换位置
        #把二元索引变成一元索引
        relative_coords[:, :, 0] += self.window_size[0] - 1  # 坐标转换为从0开始
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 #行坐标乘(2M-1)
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw] 最后一个维度求和
        self.register_buffer("relative_position_index", relative_position_index) #注册为不参与网络学习的变量,
                                                    # #作用是根据最终的相对位置索引 找到对应的可学习的相对位置编码

以上是关于Swin Transformer代码实现部分细节重点的主要内容,如果未能解决你的问题,请参考以下文章

pytorch 笔记: Swin-Transformer 代码

Swin Transformer v2实战:使用Swin Transformer v2实现图像分类

计算机视觉算法——Vision Transformer / Swin Transformer

Swin Transformer实战:使用 Swin Transformer实现图像分类。

Swin Transformer实战:使用 Swin Transformer实现图像分类。

Swin Transformer v2实战:使用Swin Transformer v2实现图像分类