pytorch进行维度变换以及形状变换

Posted Coding With you.....

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch进行维度变换以及形状变换相关的知识,希望对你有一定的参考价值。

1.形状变换中长勇的函数是reshape和view,那么这两个函数的使用有什么不同呢

首先说相同之处:都可以将张量的形状进行变换,便于我们在不同的模块中进行使用

然后说明不同之处:view是要求张量的物理存储是连续的,如果不是连续的则会报错,当然z如果想要继续使用该函数,需要使用contiguous()函数进行转换——这时的作用就可以等价于reshape

a = torch.arange(9).reshape(3, 3)   # 初始化张量a
b1 = a.permute(1, 0)  #这个相当于是浅拷贝,只进行了复制 变换,并没有生成新的内存空间
b = a.permute(1, 0).contiguous()       # 转置,并转换为符合连续性条件的tensor    这个就是深拷贝,b不是a的原始地址  而是重新分配的地址
m=b.storage().data_ptr()
m1=b1.storage().data_ptr()
m2=m1=a.storage().data_ptr()
m!=m1=m2

因此,如果是连续的物理空间下,就使用view比较节省空间;如果是不连续的情况下,两种方法使用的效果是一样的——都会开辟新的空间

a.reshape = a.view() & a.contiguous().view() 当是有序的时等价于前者;当是无序的时等价于后者

2.进行维度变换torch.squeeze() 和torch.unsqueeze()

torch.squeeze()这个函数主要是对数据维度进行压缩,a = torch.rand(1,2,3,4,5)   # 初始化张量a
b=torch.squeeze(a)  #这里有1 2 3 4 5 降维成了2 3 4 5 去掉维数为1的维度;如果指定的维度不为1 比如torch.squeeze(a,2) ,则不作处理——则相当于是没有影响的 还是1 2 3 4 5

torch.unsqueeze(tensor,N)这个函数主要是对数据维度进行扩充,N填几就在第几个维度上进行扩充
举个例子:
a = torch.rand(1,2,3,4,5)   # 初始化张量a
c=torch.unsqueeze(a,1)  #1 1 2 3 4 5
c1=torch.unsqueeze(a,2)  #1 2 1 3 4 5
c2=torch.unsqueeze(a,3)   #1 2 3 1 4 5
c3=torch.unsqueeze(a,4) #1 2 3 4 1 5
c4=torch.unsqueeze(a,0)  #1 1 2 3 4 5
c5=torch.unsqueeze(a,5)  #1 2 3 4 5 1

3.tensor的转置

如果是二维的矩阵,可以采用torch.t或者torch.transpose进行转置

如果是高维的张量,可以使用permute()进行转置

下面是permute转置前后的结果: 

tensor([[[[0.0059, 0.7617, 0.6693, 0.3750, 0.2394],
          [0.4180, 0.3297, 0.9128, 0.4802, 0.3164],
          [0.0389, 0.5162, 0.0794, 0.6116, 0.5534],
          [0.0541, 0.6623, 0.4360, 0.9220, 0.3167]],

         [[0.2555, 0.2535, 0.1496, 0.4983, 0.7569],
          [0.3564, 0.9477, 0.4017, 0.7491, 0.2962],
          [0.6235, 0.4532, 0.9770, 0.0429, 0.1650],
          [0.8686, 0.0464, 0.5564, 0.8426, 0.4353]],

         [[0.3709, 0.5796, 0.0188, 0.9509, 0.1464],
          [0.3500, 0.8551, 0.4692, 0.0515, 0.1865],
          [0.2953, 0.7891, 0.9032, 0.4210, 0.1943],
          [0.9576, 0.9855, 0.1127, 0.0494, 0.9233]]],


        [[[0.9636, 0.1502, 0.2663, 0.8124, 0.6775],
          [0.2312, 0.1562, 0.9889, 0.9965, 0.7029],
          [0.2522, 0.8169, 0.9949, 0.9000, 0.6066],
          [0.7273, 0.1836, 0.8402, 0.4715, 0.4603]],

         [[0.2963, 0.4937, 0.5248, 0.3054, 0.8550],
          [0.2519, 0.0765, 0.2370, 0.0087, 0.7299],
          [0.1716, 0.8241, 0.5534, 0.3875, 0.7326],
          [0.5413, 0.9876, 0.7963, 0.3272, 0.4132]],

         [[0.4647, 0.3225, 0.6812, 0.3244, 0.9137],
          [0.3233, 0.7194, 0.3040, 0.1910, 0.7097],
          [0.8174, 0.7716, 0.5967, 0.8277, 0.4918],
          [0.5733, 0.0092, 0.3861, 0.2801, 0.3459]]]])

tensor([[[[0.0059, 0.7617, 0.6693, 0.3750, 0.2394],
          [0.2555, 0.2535, 0.1496, 0.4983, 0.7569],
          [0.3709, 0.5796, 0.0188, 0.9509, 0.1464]],

         [[0.4180, 0.3297, 0.9128, 0.4802, 0.3164],
          [0.3564, 0.9477, 0.4017, 0.7491, 0.2962],
          [0.3500, 0.8551, 0.4692, 0.0515, 0.1865]],

         [[0.0389, 0.5162, 0.0794, 0.6116, 0.5534],
          [0.6235, 0.4532, 0.9770, 0.0429, 0.1650],
          [0.2953, 0.7891, 0.9032, 0.4210, 0.1943]],

         [[0.0541, 0.6623, 0.4360, 0.9220, 0.3167],
          [0.8686, 0.0464, 0.5564, 0.8426, 0.4353],
          [0.9576, 0.9855, 0.1127, 0.0494, 0.9233]]],


        [[[0.9636, 0.1502, 0.2663, 0.8124, 0.6775],
          [0.2963, 0.4937, 0.5248, 0.3054, 0.8550],
          [0.4647, 0.3225, 0.6812, 0.3244, 0.9137]],

         [[0.2312, 0.1562, 0.9889, 0.9965, 0.7029],
          [0.2519, 0.0765, 0.2370, 0.0087, 0.7299],
          [0.3233, 0.7194, 0.3040, 0.1910, 0.7097]],

         [[0.2522, 0.8169, 0.9949, 0.9000, 0.6066],
          [0.1716, 0.8241, 0.5534, 0.3875, 0.7326],
          [0.8174, 0.7716, 0.5967, 0.8277, 0.4918]],

         [[0.7273, 0.1836, 0.8402, 0.4715, 0.4603],
          [0.5413, 0.9876, 0.7963, 0.3272, 0.4132],
          [0.5733, 0.0092, 0.3861, 0.2801, 0.3459]]]])

4.矩阵的拼接torch.cat和torch.stack

前者是在已经有的维度上进行拼接——给定轴的维度可以不同,其他轴的维度必须相同,后者实在新的轴上进行拼接。如下报错是因为在除拼接所在维度上的其他维度不一致

后者是在新的维度上进行拼接,要求拼接的矩阵的维度都相同——dim是在哪个维度上拼接(新增)

 5.矩阵的拆分torch.split()和torch.chunk(),前者传的参数是以哪个维度进行拆分后矩阵的大小;后者是以哪个维度拆分成多少个

以上是关于pytorch进行维度变换以及形状变换的主要内容,如果未能解决你的问题,请参考以下文章

pytorch进行维度变换以及形状变换

深度学习--PyTorch维度变换自动拓展合并与分割

PyTorch中Tensor的维度变换实现

pytorch 数据维度变换

pytorch 数据维度变换

pytorch 数据维度变换