Pytorch一文搞懂nn.Conv2d的groups参数的作用
Posted SinHao22
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch一文搞懂nn.Conv2d的groups参数的作用相关的知识,希望对你有一定的参考价值。
目录
1. 语言描述
在Pytorch1.13的官方文档中,关于nn.Conv2d中的groups的作用是这么描述的:
简单来说就是将输入和输出的通道(channel)进行分组,每一组单独进行卷积操作,然后再把结果拼接(concat)起来。
比如输入大小为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5),输出大小为 ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5), g r o u p s = 2 groups=2 groups=2。就是将输入的4个channel分成2个2的channel,输出的8个channel分成2个4的channel,每个输入的2个channel和输出的4个channel组成一组,每组做完卷积后的输出大小为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5)。然后把得到的两组输出在channel这个维度上进行concat,得到最后的输出维度为 ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5)。
但其实这么描述理解起来不够直观,下面我举个例子,先从语言上进行详细的解释,然后再进行代码验证。
符号 | 数值 | 含义 |
---|---|---|
i n p u t _ c h a n n e l input\\_channel input_channel | 4 | 输入通道数量 |
o n p u t _ c h a n n e l onput\\_channel onput_channel | 8 | 输出通道数量,其实就是卷积核的个数,我们将其看作卷积核的个数会更容易理解 |
b a t c h _ s i z e batch\\_size batch_size | 1 | 批量大小为1 |
H , W H, W H,W | 5 | 输入输出的feature大小为5x5 |
i n p u t _ s h a p e input\\_shape input_shape | ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5) | 输入的shape,注意我们这里设置输入的所有元素都为1,即输入是一个全1的tensor |
o u t p u t _ s h a p e output\\_shape output_shape | ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5) | 输出的shape |
k e r n e l _ s i z e kernel\\_size kernel_size | 3 | 卷积核的大小为3x3 |
p a d d i n g padding padding | 1 | 填充长度为1,这里我们使用1填充(即周围补一圈1),而不是0填充 |
s t r i d e stride stride | 1 | 步长为1 |
我们假设输入tensor的shape为
(
1
,
4
,
5
,
5
)
(1, 4, 5, 5)
(1,4,5,5),输出tensor的shape为:
(
1
,
8
,
5
,
5
)
(1, 8, 5, 5)
(1,8,5,5),即我们的卷积核有8个。下面的图由于
b
a
t
c
h
_
s
i
z
e
=
1
batch\\_size=1
batch_size=1,所以省略的
b
a
t
c
h
_
s
i
z
e
batch\\_size
batch_size的维度。
值得注意的是,这里我们手动设置卷积核中元素的值,前4个卷积核的值都设置为1,后4个卷积核的值都设置为2,如下图所示:
这里解释一下为什么
g
r
o
u
p
s
=
1
groups=1
groups=1时
k
e
r
n
e
l
_
s
i
z
e
=
(
4
,
3
,
3
)
kernel\\_size=(4, 3, 3)
kernel_size=(4,3,3),
g
r
o
u
p
s
=
2
groups=2
groups=2时
k
e
r
n
e
l
_
s
i
z
e
=
(
2
,
3
,
3
)
kernel\\_size=(2, 3, 3)
kernel_size=(2,3,3):因为
g
r
o
u
p
s
=
2
groups=2
groups=2时,输入和输出都被分成了两组,输入的shape原来为:
(
4
,
5
,
5
)
(4, 5, 5)
(4,5,5),被分成了两个
(
2
,
5
,
5
)
(2, 5, 5)
(2,5,5),所以每个
k
e
r
n
e
l
_
s
i
z
e
kernel\\_size
kernel_size也由
(
4
,
3
,
3
)
(4, 3, 3)
(4,3,3)变为
(
2
,
3
,
3
)
(2, 3, 3)
(2,3,3)。
下面我们来看一下 g r o u p s = 1 groups=1 groups=1和 g r o u p s = 2 groups=2 groups=2时计算过程的不同:
【情况1:groups=1】
此时就和正常卷积一样:
这里解释一下:output的前4个channel的每个feature map的所有元素都为36,后4个channel的每个feature map的所有元素都为72,这是因为:
每个输入的
H
,
W
H,W
H,W是5x5,加上padding之后是6x6,具体过程如下:
【情况1:groups=2】
此时应当这么算:
为什么output的前4个channel的每个feature map的所有元素都为18,后4个channel的每个feature map的所有元素都为36呢?看了下面的图应该就能理解这个过程了:
2. 代码验证:
实验环境:Python3.7,torch1.10.2
代码:
import os
import torch
import torch.nn as nn
if __name__ == '__main__':
input_dim, output_dim = 4, 8
X = torch.ones(1, input_dim, 5, 5)
# groups = 1
conv1 = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1, groups=1, bias=False, padding_mode='replicate')
print(f'groups=1时,卷积核的形状为:conv1.weight.shape')
with torch.no_grad():
conv1.weight[:4, :, :, :] = torch.ones(4, 4, 3, 3)
conv1.weight[4:, :, :, :] = torch.ones(4, 4, 3, 3) * 2
Y1 = conv1(X)
print(f'结果为:\\nY1')
# groups = 2
conv2 = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1, groups=2, bias=False, padding_mode='replicate')
print(f'groups=2时,卷积核的形状为:conv2.weight.shape')
with torch.no_grad():
conv2.weight[:4, :, :, :] = torch.ones(4, 2, 3, 3)
conv2.weight[4:, :, :, :] = torch.ones(4, 2, 3, 3) * 2
Y2 = conv2(X)
print(f'结果为:\\nY2')
结果:
groups=1时,卷积核的形状为:torch.Size([8, 4, 3, 3])
结果为:
tensor([[[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]],
[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]],
[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]],
[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]],
[[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.]],
[[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.]],
[[72., 72., 72., 72., 72.],
以上是关于Pytorch一文搞懂nn.Conv2d的groups参数的作用的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch重要函数(nn.Conv2d;nn.ConvTranspose2d)