PyTorch实现SqueezeNet的Fire模块

Posted 算法与编程之美

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch实现SqueezeNet的Fire模块相关的知识,希望对你有一定的参考价值。

问题

SqueezeNet是一款非常经典的CV网络,其设计理念对后续的很多网络都有非常强的指导意义,其核心思想包括:

  • 使用1x1卷积核替代3x3,主要原因是3x3的卷积核参数量是1x1的9倍多;
  • 降低3x3卷积核的通道数量;
  • 网络结构中延迟下采样的时机以获得较大尺寸的激活特征图;

方法

下面介绍PyTorch实现的SqueezeNet网络最核心的Fire模块,如下:

import torch
from torch import nn, Tensor
from typing import Any

class BasicConv2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: # 增加in_xxx和out_xxx的好处是,调用的时候可以省略参数名
        super().__init__()
        self.conv2d = nn.Conv2d(in_channels, out_channels, **kwargs) # **容易漏掉
        self.relu = nn.ReLU()
        
    def forward(self, x: Tensor) -> Tensor:
        x = self.conv2d(x)
        out = self.relu(x)
        
        return out

class Fire(nn.Module):
    
    def __init__(self, in_channels: int, s_1x1: int, e_1x1: int, e_3x3: int) -> None:
        super().__init__()
        
        self.squeeze = BasicConv2d(in_channels, s_1x1, kernel_size=1)
        
        self.expand_1x1 = BasicConv2d(s_1x1, e_1x1, kernel_size = 1)
        self.expand_3x3 = BasicConv2d(s_1x1, e_3x3, kernel_size = 3, padding = 1) # p=1是为了保持3x3特征图不变
    
    def forward(self, x: Tensor) -> Tensor:
        x = self.squeeze(x)
        
        return torch.cat([
            self.expand_1x1(x), 
            self.expand_3x3(x)
        ], dim=1)


if __name__ == '__main__':
    
    x = torch.rand(size=(1, 3, 224, 224))
    
    conv2d = BasicConv2d(3,  64, kernel_size = 3, padding = 1, stride = 1)
    print(conv2d(x).shape) # torch.Size([1, 64, 224, 224])   
    
    fire = Fire(3, 32, 32, 48)
    print(fire(x).shape) # torch.Size([1, 80, 224, 224])
    
    

结语

以上是关于PyTorch实现SqueezeNet的Fire模块的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch TextCNN实现中文文本分类(附完整训练代码)

轻量化卷积神经网络模型总结by wilson(shffleNet,moblieNet,squeezeNet+Xception)

libtorch(pytorch c++)教程

libtorch(pytorch c++)教程

超轻量级网络SqueezeNet网络解读

SqueezeNet模型参数降低50倍,压缩461倍