pytorch进行维度变换以及形状变换
Posted Coding With you.....
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch进行维度变换以及形状变换相关的知识,希望对你有一定的参考价值。
1.形状变换中长勇的函数是reshape和view,那么这两个函数的使用有什么不同呢
首先说相同之处:都可以将张量的形状进行变换,便于我们在不同的模块中进行使用
然后说明不同之处:view是要求张量的物理存储是连续的,如果不是连续的则会报错,当然z如果想要继续使用该函数,需要使用contiguous()函数进行转换——这时的作用就可以等价于reshape
a = torch.arange(9).reshape(3, 3) # 初始化张量a b1 = a.permute(1, 0) #这个相当于是浅拷贝,只进行了复制 变换,并没有生成新的内存空间 b = a.permute(1, 0).contiguous() # 转置,并转换为符合连续性条件的tensor 这个就是深拷贝,b不是a的原始地址 而是重新分配的地址 m=b.storage().data_ptr() m1=b1.storage().data_ptr() m2=m1=a.storage().data_ptr() m!=m1=m2
因此,如果是连续的物理空间下,就使用view比较节省空间;如果是不连续的情况下,两种方法使用的效果是一样的——都会开辟新的空间
a.reshape = a.view() & a.contiguous().view() 当是有序的时等价于前者;当是无序的时等价于后者
2.进行维度变换torch.squeeze() 和torch.unsqueeze()
torch.squeeze()这个函数主要是对数据维度进行压缩,a = torch.rand(1,2,3,4,5) # 初始化张量a b=torch.squeeze(a) #这里有1 2 3 4 5 降维成了2 3 4 5 去掉维数为1的维度;如果指定的维度不为1 比如torch.squeeze(a,2) ,则不作处理——则相当于是没有影响的 还是1 2 3 4 5 torch.unsqueeze(tensor,N)这个函数主要是对数据维度进行扩充,N填几就在第几个维度上进行扩充 举个例子: a = torch.rand(1,2,3,4,5) # 初始化张量a c=torch.unsqueeze(a,1) #1 1 2 3 4 5 c1=torch.unsqueeze(a,2) #1 2 1 3 4 5 c2=torch.unsqueeze(a,3) #1 2 3 1 4 5 c3=torch.unsqueeze(a,4) #1 2 3 4 1 5 c4=torch.unsqueeze(a,0) #1 1 2 3 4 5 c5=torch.unsqueeze(a,5) #1 2 3 4 5 1
3.tensor的转置
如果是二维的矩阵,可以采用torch.t或者torch.transpose进行转置
如果是高维的张量,可以使用permute()进行转置
下面是permute转置前后的结果:
tensor([[[[0.0059, 0.7617, 0.6693, 0.3750, 0.2394], [[0.2555, 0.2535, 0.1496, 0.4983, 0.7569], [[0.3709, 0.5796, 0.0188, 0.9509, 0.1464],
[[0.2963, 0.4937, 0.5248, 0.3054, 0.8550], [[0.4647, 0.3225, 0.6812, 0.3244, 0.9137], | tensor([[[[0.0059, 0.7617, 0.6693, 0.3750, 0.2394], [[0.4180, 0.3297, 0.9128, 0.4802, 0.3164], [[0.0389, 0.5162, 0.0794, 0.6116, 0.5534], [[0.0541, 0.6623, 0.4360, 0.9220, 0.3167],
[[0.2312, 0.1562, 0.9889, 0.9965, 0.7029], [[0.2522, 0.8169, 0.9949, 0.9000, 0.6066], [[0.7273, 0.1836, 0.8402, 0.4715, 0.4603], |
4.矩阵的拼接torch.cat和torch.stack
前者是在已经有的维度上进行拼接——给定轴的维度可以不同,其他轴的维度必须相同,后者实在新的轴上进行拼接。如下报错是因为在除拼接所在维度上的其他维度不一致
后者是在新的维度上进行拼接,要求拼接的矩阵的维度都相同——dim是在哪个维度上拼接(新增)
5.矩阵的拆分torch.split()和torch.chunk(),前者传的参数是以哪个维度进行拆分后矩阵的大小;后者是以哪个维度拆分成多少个
以上是关于pytorch进行维度变换以及形状变换的主要内容,如果未能解决你的问题,请参考以下文章