PyTorch笔记 - Convolution卷积运算的原理

Posted SpikeKing

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch笔记 - Convolution卷积运算的原理相关的知识,希望对你有一定的参考价值。

DilatedConv、GroupConv,膨胀卷积、组卷积,源码:torch.nn.Conv2d

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

当dilation > 1时,卷积核不紧凑

import torch
import torch.nn as nn
import torch.nn.functional as F

a = torch.randn(7, 7)
print(f"a: a")
a[0:3, 0:3]  # dilation=1
a[0:5:2, 0:5:2]  # dilation=2
a[0:7:3, 0:7:3]  # dilation=3
# a[0:dx2+1:d]

当group>1时,分组卷积,再进行合

# group convolution
in_channel, out_channel = 2, 4
# kernel:2x4,8个卷积核
groups = 2
# 每组有2个卷积组,一个4个卷积核,输入和输出大小不变 
sub_in_channel, sub_out_channel = 1, 2  
# 把结果拼起来,通道融合并不充分,只需要在每个group内进行融合,最后拼接
# 再使用1x1卷积,进行通道融合 1x1 point-wise convolution

Convolution with DilatedConv and GroupConv:

import torch
import torch.nn as nn
import torch.nn.functional as F


def matrix_multiplication_for_conv2d_final(input, kernel, bias=0, stride=1, padding=0, dilation=1, groups=1):
    if padding > 0:
        # 从里到外,width、height、channel、batch
        input = F.pad(input, (padding, padding, padding, padding, 0, 0, 0, 0))  
        
    bs, in_channel, input_h, input_w = input.shape
    # kernel一共4维,包含通道融合的功能
    out_channel, _, kernel_h, kernel_w = kernel.shape
    
    assert out_channel%groups==0 and in_channel%groups==0, "groups必须要同时被输入通道和输出通道数整除!"
    input = input.reshape((bs, groups, in_channel//groups, input_h, input_w))
    kernel = kernel.reshape((groups, out_channel//groups, in_channel//groups, kernel_h, kernel_w))
    
    kernel_h = (kernel_h-1)*(dilation-1) + kernel_h  # 例如k=3,d=2,new_k=2x1+3=5
    kernel_w = (kernel_w-1)*(dilation-1) + kernel_w
    
    
    if bias is None:
        bias = torch.zeros(out_channel)
        
    # 向下取整floor, 直接pad到input,不用padding
    output_h = (input_h - kernel_h) // stride + 1  # 卷积输出的高度
    output_w = (input_w - kernel_w) // stride + 1  # 卷积输出的宽度
    output = torch.zeros((bs, groups, out_channel//groups, output_h, output_w))  # 初始化输出矩阵
    
    for ind in range(bs):  # 对batchsize进行遍历
        for g in range(groups):  # 对群组进行遍历
            for oc in range(out_channel//groups):  # 对分组后的输出通道进行遍历
                for ic in range(in_channel//groups):  # 对分组后的输入通道进行遍历
                    for i in range(0, input_h-kernel_h+1, stride):  # 对高度维进行遍历,input_h已经包括padding
                        for j in range(0, input_w-kernel_w+1, stride):  # 对宽度度维进行遍历
                            region = input[ind, g, ic, i:i+kernel_h:dilation, j:j+kernel_w:dilation]
                            # 点乘,并且赋值输出位置的元素
                            output[ind, g, oc, i//stride, j//stride] += torch.sum(region * kernel[g, oc, ic])  
                output[ind, g, oc] += bias[g*(out_channel//groups) + oc]
    output = output.reshape((bs, out_channel, output_h, output_w))
    return output

# 以下为验证和测试的代码,验证与函数PyTorch API结果是否一致
bs, in_channel, input_h, input_w = 2, 2, 5, 5
kernel_size = 3
out_channel = 4
groups, dilation, stride, padding = 2, 2, 2, 1
input = torch.randn((bs, in_channel, input_h, input_w))
kernel = torch.randn((out_channel, in_channel//groups, kernel_size, kernel_size))
bias = torch.randn(out_channel)

# PyTorch的官方API
pytorch_conv2d_api_output = F.conv2d(input, kernel, bias=bias, padding=padding, \\
                                     stride=stride, dilation=dilation, groups=groups)
mm_conv2d_final_output = matrix_multiplication_for_conv2d_final(input, kernel, bias=bias, padding=padding, \\
                                     stride=stride, dilation=dilation, groups=groups)
print(f"pytorch_conv2d_api_output: pytorch_conv2d_api_output")
print(f"mm_conv2d_final_output: mm_conv2d_final_output")
flag = torch.allclose(pytorch_conv2d_api_output, mm_conv2d_final_output)
print(f"flag: flag")

以上是关于PyTorch笔记 - Convolution卷积运算的原理的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch笔记 - Convolution卷积运算的原理

PyTorch笔记 - Convolution卷积运算的原理

PyTorch笔记 - Convolution卷积运算的原理

PyTorch笔记 - Convolution卷积运算的原理

PyTorch笔记 - Convolution卷积运算的原理

PyTorch笔记 - Convolution卷积运算的原理