Pytorch一文搞懂nn.Conv2d的groups参数的作用

Posted SinHao22

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch一文搞懂nn.Conv2d的groups参数的作用相关的知识,希望对你有一定的参考价值。

目录

1. 语言描述

在Pytorch1.13的官方文档中,关于nn.Conv2d中的groups的作用是这么描述的:

简单来说就是将输入和输出的通道(channel)进行分组,每一组单独进行卷积操作,然后再把结果拼接(concat)起来。

比如输入大小为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5),输出大小为 ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5) g r o u p s = 2 groups=2 groups=2。就是将输入的4个channel分成2个2的channel,输出的8个channel分成2个4的channel,每个输入的2个channel和输出的4个channel组成一组,每组做完卷积后的输出大小为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5)。然后把得到的两组输出在channel这个维度上进行concat,得到最后的输出维度为 ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5)

但其实这么描述理解起来不够直观,下面我举个例子,先从语言上进行详细的解释,然后再进行代码验证。

符号数值含义
i n p u t _ c h a n n e l input\\_channel input_channel4输入通道数量
o n p u t _ c h a n n e l onput\\_channel onput_channel8输出通道数量,其实就是卷积核的个数,我们将其看作卷积核的个数会更容易理解
b a t c h _ s i z e batch\\_size batch_size1批量大小为1
H , W H, W H,W5输入输出的feature大小为5x5
i n p u t _ s h a p e input\\_shape input_shape ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5)输入的shape,注意我们这里设置输入的所有元素都为1,即输入是一个全1的tensor
o u t p u t _ s h a p e output\\_shape output_shape ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5)输出的shape
k e r n e l _ s i z e kernel\\_size kernel_size3卷积核的大小为3x3
p a d d i n g padding padding1填充长度为1,这里我们使用1填充(即周围补一圈1),而不是0填充
s t r i d e stride stride1步长为1

我们假设输入tensor的shape为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5)输出tensor的shape为: ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5),即我们的卷积核有8个。下面的图由于 b a t c h _ s i z e = 1 batch\\_size=1 batch_size=1,所以省略的 b a t c h _ s i z e batch\\_size batch_size的维度。

值得注意的是,这里我们手动设置卷积核中元素的值,前4个卷积核的值都设置为1,后4个卷积核的值都设置为2,如下图所示:


这里解释一下为什么 g r o u p s = 1 groups=1 groups=1 k e r n e l _ s i z e = ( 4 , 3 , 3 ) kernel\\_size=(4, 3, 3) kernel_size=(4,3,3) g r o u p s = 2 groups=2 groups=2 k e r n e l _ s i z e = ( 2 , 3 , 3 ) kernel\\_size=(2, 3, 3) kernel_size=(2,3,3):因为 g r o u p s = 2 groups=2 groups=2时,输入和输出都被分成了两组,输入的shape原来为: ( 4 , 5 , 5 ) (4, 5, 5) (4,5,5),被分成了两个 ( 2 , 5 , 5 ) (2, 5, 5) (2,5,5),所以每个 k e r n e l _ s i z e kernel\\_size kernel_size也由 ( 4 , 3 , 3 ) (4, 3, 3) (4,3,3)变为 ( 2 , 3 , 3 ) (2, 3, 3) (2,3,3)

下面我们来看一下 g r o u p s = 1 groups=1 groups=1 g r o u p s = 2 groups=2 groups=2时计算过程的不同:

【情况1:groups=1】
此时就和正常卷积一样:

这里解释一下:output的前4个channel的每个feature map的所有元素都为36,后4个channel的每个feature map的所有元素都为72,这是因为:
每个输入的 H , W H,W H,W是5x5,加上padding之后是6x6,具体过程如下:

【情况1:groups=2】
此时应当这么算:

为什么output的前4个channel的每个feature map的所有元素都为18,后4个channel的每个feature map的所有元素都为36呢?看了下面的图应该就能理解这个过程了:

2. 代码验证:

实验环境:Python3.7,torch1.10.2
代码:

import os

import torch
import torch.nn as nn


if __name__ == '__main__':
    input_dim, output_dim = 4, 8
    X = torch.ones(1, input_dim, 5, 5)

    # groups = 1
    conv1 = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1, groups=1, bias=False, padding_mode='replicate')
    print(f'groups=1时,卷积核的形状为:conv1.weight.shape')
    with torch.no_grad():
        conv1.weight[:4, :, :, :] = torch.ones(4, 4, 3, 3)
        conv1.weight[4:, :, :, :] = torch.ones(4, 4, 3, 3) * 2
        Y1 = conv1(X)
        print(f'结果为:\\nY1')

    # groups = 2
    conv2 = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1, groups=2, bias=False, padding_mode='replicate')
    print(f'groups=2时,卷积核的形状为:conv2.weight.shape')
    with torch.no_grad():
        conv2.weight[:4, :, :, :] = torch.ones(4, 2, 3, 3)
        conv2.weight[4:, :, :, :] = torch.ones(4, 2, 3, 3) * 2
        Y2 = conv2(X)
        print(f'结果为:\\nY2')


结果:

groups=1时,卷积核的形状为:torch.Size([8, 4, 3, 3])
结果为:
tensor([[[[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]],

         [[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]],

         [[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]],

         [[36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.],
          [36., 36., 36., 36., 36.]],

         [[72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.]],

         [[72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.],
          [72., 72., 72., 72., 72.]],

         [[72., 72., 72., 72., 72.],
          以上是关于Pytorch一文搞懂nn.Conv2d的groups参数的作用的主要内容,如果未能解决你的问题,请参考以下文章

pytorch 笔记:torch.nn.Conv2d

Pytorch重要函数(nn.Conv2d;nn.ConvTranspose2d)

PyTorch网络搭建中*list的用法解析

[Pytorch系列-31]:卷积神经网络 - torch.nn.Conv2d() 用法详解

pytorch(网络模型)

pytorch(网络模型)