Pytorch CIFAR10图像分类 MobileNet v1篇

Posted 风信子的猫Redamancy

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch CIFAR10图像分类 MobileNet v1篇相关的知识,希望对你有一定的参考价值。

Pytorch CIFAR10图像分类 MobileNet v1篇

文章目录

4.定义网络(MobileNet v1)

在之前的文章中讲的AlexNet、VGG、GoogLeNet以及ResNet网络,它们都是传统卷积神经网络(都是使用的传统卷积层),缺点在于内存需求大、运算量大导致无法在移动设备以及嵌入式设备上运行。而本文要讲的MobileNet网络就是专门为移动端,嵌入式端而设计。

我也看了论文,如果想仔细研究一下MobileNet的话,可以看我的另一篇博客【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications

MobileNet网络是由google团队在2017年提出的,专注于移动端或者嵌入式设备中的轻量级CNN网络。相比传统卷积神经网络,在准确率小幅降低的前提下大大减少模型参数与运算量。(相比VGG16准确率减少了0.9%,但模型参数只有VGG的1/32)。

要说MobileNet网络的优点,无疑是其中的Depthwise Convolution结构(大大减少运算量和参数数量)。下图展示了传统卷积与DW卷积的差异,在传统卷积中,每个卷积核的channel与输入特征矩阵的channel相等(每个卷积核都会与输入特征矩阵的每一个维度进行卷积运算)。而在DW卷积中,每个卷积核的channel都是等于1的(每个卷积核只负责输入特征矩阵的一个channel,故卷积核的个数必须等于输入特征矩阵的channel数,从而使得输出特征矩阵的channel数也等于输入特征矩阵的channel数)

刚刚说了使用DW卷积后输出特征矩阵的channel是与输入特征矩阵的channel相等的,如果想改变/自定义输出特征矩阵的channel,那只需要在DW卷积后接上一个PW卷积即可,如下图所示,其实PW卷积就是普通的卷积而已(只不过卷积核大小为1)。通常DW卷积和PW卷积是放在一起使用的,一起叫做Depthwise Separable Convolution(深度可分卷积)

那Depthwise Separable Convolution(深度可分卷积)与传统的卷积相比有到底能节省多少计算量呢,下图对比了这两个卷积方式的计算量,其中Df是输入特征矩阵的宽高(这里假设宽和高相等),Dk是卷积核的大小,M是输入特征矩阵的channel,N是输出特征矩阵的channel,卷积计算量近似等于卷积核的高 x 卷积核的宽 x 卷积核的channel x 输入特征矩阵的高 x 输入特征矩阵的宽(这里假设stride等于1),在我们mobilenet网络中DW卷积都是是使用3x3大小的卷积核。所以理论上普通卷积计算量是DW+PW卷积的8到9倍(公式来源于原论文):

在了解完Depthwise Separable Convolution(深度可分卷积)后在看下mobilenet v1的网络结构,左侧的表格是mobileNetv1的网络结构,表中标Conv的表示普通卷积,Conv dw代表刚刚说的DW卷积,s表示步距,根据表格信息就能很容易的搭建出mobileNet v1网络。在mobilenetv1原论文中,还提出了两个超参数,一个是α一个是β。α参数是一个倍率因子,用来调整卷积核的个数,β是控制输入网络的图像尺寸参数,下图右侧给出了使用不同α和β网络的分类准确率,计算量以及模型参数:

首先我们还是得判断是否可以利用GPU,因为GPU的速度可能会比我们用CPU的速度快20-50倍左右,特别是对卷积神经网络来说,更是提升特别明显。

device = 'cuda' if torch.cuda.is_available() else 'cpu'

接着我们可以定义网络,在pytorch之中,定义我们的深度可分离卷积来说,我们需要调一个groups参数,就可以构建深度可分离卷积了。

class Block(nn.Module):
    '''Depthwise conv + Pointwise conv'''
    def __init__(self,in_channels,out_channels,stride=1):
        super(Block,self).__init__()
        # groups参数就是深度可分离卷积的关键
        self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=stride,
                               padding=1,groups=in_channels,bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=1,padding=0,bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()
    def forward(self,x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        return x
        
# 深度可分离卷积 DepthWise Separable Convolution
class MobileNetV1(nn.Module):
    # (128,2) means conv channel=128, conv stride=2, by default conv stride=1
    cfg = [64,(128,2),128,(256,2),256,(512,2),512,512,512,512,512,(1024,2),1024]
    
    def __init__(self, num_classes=10,alpha=1.0,beta=1.0):
        super(MobileNetV1,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3,32,kernel_size=3,stride=1,bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.avg = nn.AvgPool2d(kernel_size=2)
        self.layers = self._make_layers(in_channels=32)
        self.linear = nn.Linear(1024,num_classes)
    
    def _make_layers(self, in_channels):
        layers = []
        for x in self.cfg:
            out_channels = x if isinstance(x,int) else x[0]
            stride = 1 if isinstance(x,int) else x[1]
            layers.append(Block(in_channels,out_channels,stride))
            in_channels = out_channels
        return nn.Sequential(*layers)
    
    def forward(self,x):
        x = self.conv1(x)
        x = self.layers(x)
        x = self.avg(x)
        x = x.view(x.size()[0],-1)
        x = self.linear(x)
        return x
net = MobileNetV1(num_classes=10).to(device)
summary(net,(2,3,32,32))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
MobileNetV1                              --                        --
├─Sequential: 1-1                        [2, 32, 30, 30]           --
│    └─Conv2d: 2-1                       [2, 32, 30, 30]           864
│    └─BatchNorm2d: 2-2                  [2, 32, 30, 30]           64
│    └─ReLU: 2-3                         [2, 32, 30, 30]           --
├─Sequential: 1-2                        [2, 1024, 2, 2]           --
│    └─Block: 2-4                        [2, 64, 30, 30]           --
│    │    └─Conv2d: 3-1                  [2, 32, 30, 30]           288
│    │    └─BatchNorm2d: 3-2             [2, 32, 30, 30]           64
│    │    └─ReLU: 3-3                    [2, 32, 30, 30]           --
│    │    └─Conv2d: 3-4                  [2, 64, 30, 30]           2,048
│    │    └─BatchNorm2d: 3-5             [2, 64, 30, 30]           128
│    │    └─ReLU: 3-6                    [2, 64, 30, 30]           --
│    └─Block: 2-5                        [2, 128, 15, 15]          --
│    │    └─Conv2d: 3-7                  [2, 64, 15, 15]           576
│    │    └─BatchNorm2d: 3-8             [2, 64, 15, 15]           128
│    │    └─ReLU: 3-9                    [2, 64, 15, 15]           --
│    │    └─Conv2d: 3-10                 [2, 128, 15, 15]          8,192
│    │    └─BatchNorm2d: 3-11            [2, 128, 15, 15]          256
│    │    └─ReLU: 3-12                   [2, 128, 15, 15]          --
│    └─Block: 2-6                        [2, 128, 15, 15]          --
│    │    └─Conv2d: 3-13                 [2, 128, 15, 15]          1,152
│    │    └─BatchNorm2d: 3-14            [2, 128, 15, 15]          256
│    │    └─ReLU: 3-15                   [2, 128, 15, 15]          --
│    │    └─Conv2d: 3-16                 [2, 128, 15, 15]          16,384
│    │    └─BatchNorm2d: 3-17            [2, 128, 15, 15]          256
│    │    └─ReLU: 3-18                   [2, 128, 15, 15]          --
│    └─Block: 2-7                        [2, 256, 8, 8]            --
│    │    └─Conv2d: 3-19                 [2, 128, 8, 8]            1,152
│    │    └─BatchNorm2d: 3-20            [2, 128, 8, 8]            256
│    │    └─ReLU: 3-21                   [2, 128, 8, 8]            --
│    │    └─Conv2d: 3-22                 [2, 256, 8, 8]            32,768
│    │    └─BatchNorm2d: 3-23            [2, 256, 8, 8]            512
│    │    └─ReLU: 3-24                   [2, 256, 8, 8]            --
│    └─Block: 2-8                        [2, 256, 8, 8]            --
│    │    └─Conv2d: 3-25                 [2, 256, 8, 8]            2,304
│    │    └─BatchNorm2d: 3-26            [2, 256, 8, 8]            512
│    │    └─ReLU: 3-27                   [2, 256, 8, 8]            --
│    │    └─Conv2d: 3-28                 [2, 256, 8, 8]            65,536
│    │    └─BatchNorm2d: 3-29            [2, 256, 8, 8]            512
│    │    └─ReLU: 3-30                   [2, 256, 8, 8]            --
│    └─Block: 2-9                        [2, 512, 4, 4]            --
│    │    └─Conv2d: 3-31                 [2, 256, 4, 4]            2,304
│    │    └─BatchNorm2d: 3-32            [2, 256, 4, 4]            512
│    │    └─ReLU: 3-33                   [2, 256, 4, 4]            --
│    │    └─Conv2d: 3-34                 [2, 512, 4, 4]            131,072
│    │    └─BatchNorm2d: 3-35            [2, 512, 4, 4]            1,024
│    │    └─ReLU: 3-36                   [2, 512, 4, 4]            --
│    └─Block: 2-10                       [2, 512, 4, 4]            --
│    │    └─Conv2d: 3-37                 [2, 512, 4, 4]            4,608
│    │    └─BatchNorm2d: 3-38            [2, 512, 4, 4]            1,024
│    │    └─ReLU: 3-39                   [2, 512, 4, 4]            --
│    │    └─Conv2d: 3-40                 [2, 512, 4, 4]            262,144
│    │    └─BatchNorm2d: 3-41            [2, 512, 4, 4]            1,024
│    │    └─ReLU: 3-42                   [2, 512, 4, 4]            --
│    └─Block: 2-11                       [2, 512, 4, 4]            --
│    │    └─Conv2d: 3-43                 [2, 512, 4, 4]            4,608
│    │    └─BatchNorm2d: 3-44            [2, 512, 4, 4]            1,024
│    │    └─ReLU: 3-45                   [2, 512, 4, 4]            --
│    │    └─Conv2d: 3-46                 [2, 512, 4, 4]            262,144
│    │    └─BatchNorm2d: 3-47            [2, 512, 4, 4]            1,024
│    │    └─ReLU: 3-48                   [2, 512, 4, 4]            --
│    └─Block: 2-12                       [2, 512, 4, 4]            --
│    │    └─Conv2d: 3-49                 [2, 512, 4, 4]            Pytorch CIFAR10图像分类 ResNeXt篇

Pytorch CIFAR10图像分类 ResNeXt篇

Pytorch CIFAR10图像分类 ResNet篇

Pytorch CIFAR10图像分类 EfficientNet v1篇

Pytorch CIFAR10图像分类 EfficientNet v1篇

Pytorch CIFAR10图像分类 EfficientNet v1篇