PyTorch实现SqueezeNet的Fire模块
Posted 算法与编程之美
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch实现SqueezeNet的Fire模块相关的知识,希望对你有一定的参考价值。
问题
SqueezeNet是一款非常经典的CV网络,其设计理念对后续的很多网络都有非常强的指导意义,其核心思想包括:
- 使用1x1卷积核替代3x3,主要原因是3x3的卷积核参数量是1x1的9倍多;
- 降低3x3卷积核的通道数量;
- 网络结构中延迟下采样的时机以获得较大尺寸的激活特征图;
方法
下面介绍PyTorch实现的SqueezeNet网络最核心的Fire模块,如下:
import torch
from torch import nn, Tensor
from typing import Any
class BasicConv2d(nn.Module):
def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: # 增加in_xxx和out_xxx的好处是,调用的时候可以省略参数名
super().__init__()
self.conv2d = nn.Conv2d(in_channels, out_channels, **kwargs) # **容易漏掉
self.relu = nn.ReLU()
def forward(self, x: Tensor) -> Tensor:
x = self.conv2d(x)
out = self.relu(x)
return out
class Fire(nn.Module):
def __init__(self, in_channels: int, s_1x1: int, e_1x1: int, e_3x3: int) -> None:
super().__init__()
self.squeeze = BasicConv2d(in_channels, s_1x1, kernel_size=1)
self.expand_1x1 = BasicConv2d(s_1x1, e_1x1, kernel_size = 1)
self.expand_3x3 = BasicConv2d(s_1x1, e_3x3, kernel_size = 3, padding = 1) # p=1是为了保持3x3特征图不变
def forward(self, x: Tensor) -> Tensor:
x = self.squeeze(x)
return torch.cat([
self.expand_1x1(x),
self.expand_3x3(x)
], dim=1)
if __name__ == '__main__':
x = torch.rand(size=(1, 3, 224, 224))
conv2d = BasicConv2d(3, 64, kernel_size = 3, padding = 1, stride = 1)
print(conv2d(x).shape) # torch.Size([1, 64, 224, 224])
fire = Fire(3, 32, 32, 48)
print(fire(x).shape) # torch.Size([1, 80, 224, 224])
结语
以上是关于PyTorch实现SqueezeNet的Fire模块的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch TextCNN实现中文文本分类(附完整训练代码)
轻量化卷积神经网络模型总结by wilson(shffleNet,moblieNet,squeezeNet+Xception)