PyTorch Basics (13): the torch.nn.Unfold() method

Posted by 奋斗丶


Preface

While reading a recent paper, I noticed that its code was remarkably concise: the paper's idea was expressed efficiently with nothing but the unfold and fold methods. So I'm writing up my study notes on unfold and fold here.

1. Method Details

  • Method
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
  • Parameters
    • kernel_size (int or tuple) – size of the sliding window

    • stride (int or tuple, optional) – stride of the window over the spatial dimensions. Default: 1

    • padding (int or tuple, optional) – implicit zero padding to be added on both sides of the input. Default: 0

    • dilation (int or tuple, optional) – spacing between kernel elements, as in dilated convolution. Default: 1
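As a quick sketch of how these parameters affect the output (the tensor sizes below are my own illustration, not from the original post), stride and padding each change the number of blocks that get extracted:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 6, 6)

# Default stride=1: (6 - 2)/1 + 1 = 5 blocks per dimension -> 25 blocks.
print(nn.Unfold(kernel_size=2)(x).shape)             # torch.Size([1, 4, 25])

# stride=2: (6 - 2)/2 + 1 = 3 blocks per dimension -> 9 blocks.
print(nn.Unfold(kernel_size=2, stride=2)(x).shape)   # torch.Size([1, 4, 9])

# padding=1: (6 + 2 - 2)/1 + 1 = 7 blocks per dimension -> 49 blocks.
print(nn.Unfold(kernel_size=2, padding=1)(x).shape)  # torch.Size([1, 4, 49])
```

In every case the middle dimension stays at 1 × 2 × 2 = 4, since it depends only on the channel count and kernel size.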

  • Explanation: Unfold extracts all values covered by the sliding window. For example, given the matrix below,
[[ 0.4009,  0.6350, -0.5197,  0.8148, -0.7235],
[-1.2102,  0.4621, -0.3421, -0.9261, -2.8376],
[-1.5553,  0.1713,  0.6820, -2.0880, -0.0204],
[ 1.1419, -0.4881, -0.9510, -0.0367, -0.8108],
[ 0.1459, -0.4568,  1.0039, -1.2385, -1.4467]]

sliding a window with kernel_size=3 over it extracts the following (each row below is one flattened 3×3 patch):

[[ 0.4009,  0.6350, -0.5197, -1.2102,  0.4621, -0.3421, -1.5553, 0.1713,  0.6820],
 [ 0.6350, -0.5197,  0.8148,  0.4621, -0.3421, -0.9261,  0.1713, 0.6820, -2.0880],
 [-0.5197,  0.8148, -0.7235, -0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204],
 [-1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510],
 [ 0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367],
 [-0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108],
 [-1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510,  0.1459, -0.4568,  1.0039],
 [ 0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367, -0.4568, 1.0039, -1.2385],
 [ 0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108,  1.0039, -1.2385, -1.4467]]
  • Note: the input to Unfold must be 4-dimensional, i.e. of shape (N, C, H, W)
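The extraction above is easier to verify with a small deterministic tensor. Here is a sketch (the variable names are my own) showing that each column of the unfold output is one flattened 3×3 patch:

```python
import torch
import torch.nn as nn

# A 1x1x5x5 tensor holding 0..24, so patches are easy to read off.
x = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)

unfold = nn.Unfold(kernel_size=3)
out = unfold(x)                      # shape: (1, 1*3*3, 9)

# Column 0 should equal the top-left 3x3 patch, flattened row-major.
first_patch = x[0, 0, 0:3, 0:3].reshape(-1)
print(out.shape)                                # torch.Size([1, 9, 9])
print(torch.equal(out[0, :, 0], first_patch))   # True
```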

2. How to Compute the Output Size

  • Example
import torch
import torch.nn as nn

if __name__ == '__main__':
    x = torch.randn(2, 3, 5, 5)   # (B, C, H, W)
    print(x)
    unfold = nn.Unfold(2)         # kernel_size = 2
    y = unfold(x)
    print(y.size())
    print(y)
  • Result
torch.Size([2, 12, 16])

Next, let's work through how this result is computed, step by step.

First, recall that the input must be 4-dimensional, i.e. (B, C, H, W), where B is the batch size, C the number of channels, H the height of the feature map, and W its width. Suppose the size after Unfold is (B, h, w). We first compute h (the output "height"), using the formula

h = C × kernel_size[0] × kernel_size[1]

For example: suppose the number of input channels is 3 (the most common channel count for images, which is why we use it here) and kernel_size is (2, 2). After Unfold, the output height becomes h = 3 × 2 × 2 = 12.

Once h is computed, we compute w, using the formula

w = ∏ over all spatial dimensions d of floor((spatial_size[d] + 2 × padding[d] − dilation[d] × (kernel_size[d] − 1) − 1) / stride[d] + 1)

where the product runs over the d spatial dimensions; e.g. for spatial dimensions (H, W), d = 2. Let's compute the output w with an example.

For example: if the input H and W are both 5 and kernel_size is 2, then each spatial dimension contributes

floor((5 + 2 × 0 − 1 × (2 − 1) − 1) / 1 + 1) = 4

so w = 4 × 4 = 16, and the final output size is [2, 12, 16].
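Putting both formulas together, a small helper (a sketch; the function name unfold_output_size is my own) reproduces the [2, 12, 16] result and can be cross-checked against nn.Unfold itself:

```python
import math
import torch
import torch.nn as nn

def unfold_output_size(B, C, spatial, kernel, padding=0, dilation=1, stride=1):
    """Compute the (B, h, w) output size of nn.Unfold for int-valued arguments."""
    h = C * kernel ** len(spatial)  # C x product of kernel dimensions
    w = 1
    for size in spatial:
        w *= math.floor((size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)
    return (B, h, w)

print(unfold_output_size(2, 3, (5, 5), kernel=2))   # (2, 12, 16)

# Cross-check against PyTorch itself.
y = nn.Unfold(2)(torch.randn(2, 3, 5, 5))
print(tuple(y.shape))                               # (2, 12, 16)
```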

3. Case Study

  • Example
import torch
import torch.nn as nn

if __name__ == '__main__':
    x = torch.randn(1, 3, 5, 5)        # (B, C, H, W)
    print(x)
    unfold = nn.Unfold(kernel_size=3)
    output = unfold(x)                 # (1, 3*3*3, 3*3) = (1, 27, 9)
    print(output, output.size())
  • Result
tensor([[[[ 0.4009,  0.6350, -0.5197,  0.8148, -0.7235],
          [-1.2102,  0.4621, -0.3421, -0.9261, -2.8376],
          [-1.5553,  0.1713,  0.6820, -2.0880, -0.0204],
          [ 1.1419, -0.4881, -0.9510, -0.0367, -0.8108],
          [ 0.1459, -0.4568,  1.0039, -1.2385, -1.4467]],

         [[-0.9973, -0.7601, -0.2161,  1.2120, -0.3036],
          [-0.7279,  0.0833, -0.8886, -0.9168,  0.7503],
          [-0.6748,  0.7064,  0.6903, -1.0447,  0.8688],
          [-0.5230, -1.2308, -0.3932,  1.2521, -0.2523],
          [-0.3930,  0.6452,  0.1690,  0.3744,  0.2015]],

         [[ 0.6403,  1.3915, -1.9529,  0.2899, -0.8897],
          [-0.1720,  1.0843, -1.0177, -1.7480, -0.5217],
          [-0.9648, -0.0867, -0.2926,  0.3010,  0.3192],
          [ 0.1181, -0.2218,  0.0766,  0.5914, -0.8932],
          [-0.4508, -0.3964,  1.1163,  0.6776, -0.8948]]]])
tensor([[[ 0.4009,  0.6350, -0.5197, -1.2102,  0.4621, -0.3421, -1.5553,
           0.1713,  0.6820],
         [ 0.6350, -0.5197,  0.8148,  0.4621, -0.3421, -0.9261,  0.1713,
           0.6820, -2.0880],
         [-0.5197,  0.8148, -0.7235, -0.3421, -0.9261, -2.8376,  0.6820,
          -2.0880, -0.0204],
         [-1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820,  1.1419,
          -0.4881, -0.9510],
         [ 0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880, -0.4881,
          -0.9510, -0.0367],
         [-0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204, -0.9510,
          -0.0367, -0.8108],
         [-1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510,  0.1459,
          -0.4568,  1.0039],
         [ 0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367, -0.4568,
           1.0039, -1.2385],
         [ 0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108,  1.0039,
          -1.2385, -1.4467],
         [-0.9973, -0.7601, -0.2161, -0.7279,  0.0833, -0.8886, -0.6748,
           0.7064,  0.6903],
         [-0.7601, -0.2161,  1.2120,  0.0833, -0.8886, -0.9168,  0.7064,
           0.6903, -1.0447],
         [-0.2161,  1.2120, -0.3036, -0.8886, -0.9168,  0.7503,  0.6903,
          -1.0447,  0.8688],
         [-0.7279,  0.0833, -0.8886, -0.6748,  0.7064,  0.6903, -0.5230,
          -1.2308, -0.3932],
         [ 0.0833, -0.8886, -0.9168,  0.7064,  0.6903, -1.0447, -1.2308,
          -0.3932,  1.2521],
         [-0.8886, -0.9168,  0.7503,  0.6903, -1.0447,  0.8688, -0.3932,
           1.2521, -0.2523],
         [-0.6748,  0.7064,  0.6903, -0.5230, -1.2308, -0.3932, -0.3930,
           0.6452,  0.1690],
         [ 0.7064,  0.6903, -1.0447, -1.2308, -0.3932,  1.2521,  0.6452,
           0.1690,  0.3744],
         [ 0.6903, -1.0447,  0.8688, -0.3932,  1.2521, -0.2523,  0.1690,
           0.3744,  0.2015],
         [ 0.6403,  1.3915, -1.9529, -0.1720,  1.0843, -1.0177, -0.9648,
          -0.0867, -0.2926],
         [ 1.3915, -1.9529,  0.2899,  1.0843, -1.0177, -1.7480, -0.0867,
          -0.2926,  0.3010],
         [-1.9529,  0.2899, -0.8897, -1.0177, -1.7480, -0.5217, -0.2926,
           0.3010,  0.3192],
         [-0.1720,  1.0843, -1.0177, -0.9648, -0.0867, -0.2926,  0.1181,
          -0.2218,  0.0766],
         [ 1.0843, -1.0177, -1.7480, -0.0867, -0.2926,  0.3010, -0.2218,
           0.0766,  0.5914],
         [-1.0177, -1.7480, -0.5217, -0.2926,  0.3010,  0.3192,  0.0766,
           0.5914, -0.8932],
         [-0.9648, -0.0867, -0.2926,  0.1181, -0.2218,  0.0766, -0.4508,
          -0.3964,  1.1163],
         [-0.0867, -0.2926,  0.3010, -0.2218,  0.0766,  0.5914, -0.3964,
           1.1163,  0.6776],
         [-0.2926,  0.3010,  0.3192,  0.0766,  0.5914, -0.8932,  1.1163,
           0.6776, -0.8948]]]) torch.Size([1, 27, 9])
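The preface also mentioned fold, so here is a sketch of how the two relate (the variable names are my own): nn.Fold scatters the columns back into a tensor, but overlapping patches are summed rather than averaged, so dividing by the per-pixel overlap count recovers the original input:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 5, 5)

unfold = nn.Unfold(kernel_size=3)
fold = nn.Fold(output_size=(5, 5), kernel_size=3)

patches = unfold(x)       # (1, 27, 9)
summed = fold(patches)    # overlapping patches are summed, not averaged

# Count how many patches cover each pixel, then divide to undo the summing.
ones = torch.ones_like(x)
counts = fold(unfold(ones))
recovered = summed / counts

print(torch.allclose(recovered, x))  # True
```

This sum-then-normalize trick is the standard way to make fold invert unfold when the windows overlap.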

If you found this article helpful, a like, comment, or bookmark would be much appreciated!
