仿真基本功PyTorch中torch.Tensor变形的相关函数使用方法及说明
Posted AbaloneVH
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了仿真基本功PyTorch中torch.Tensor变形的相关函数使用方法及说明相关的知识,希望对你有一定的参考价值。
PyTorch中经常用到Tensor的变形。近期写GNN相关代码时,大量出现相关操作。为了避免以后仍需重复测试,特将相关内容、测试结果及个人理解整理在下面。自用的同时服务大家。同时也欢迎大家收藏和帮助我一起完善,谢谢各位~
PyTorch中torch.Tensor变形的相关函数使用方法及说明
0. 本文考虑的基本Tensor张量
import torch
a = torch.randn([3, 2, 4, 5])
b = torch.randn([3, 2, 5])
1. Tensor的形状
想知道Tensor的形状,可以使用shape函数
print(a.shape) # 输出结果:torch.Size([3, 2, 4, 5])
2. Tensor的维度置换
Tensor的维度置换可以使用permute()。
(1) 当Tensor仅有其中两个维度需要置换时,可以使用transpose()。
虽然此时transpose()的速度与permute()相同(经测试验证),但是permute()中需要把Tensor的所有维度写全,transpose()中只需要写需要置换的两个维度即可。
例如: a.transpose(1, 3) 与 a.permute(0, 3, 2, 1) 效果相同,但后者必须把a的四个维度写全
print(a.transpose(1, 3).shape) # 输出结果:torch.Size([3, 5, 4, 2])
print(a.transpose(3, 1).shape) # transpose()的两个参数的顺序不影响结果,这里输出结果与上面相同(个人强迫症,喜欢把小的放在前面)
(2) 当Tensor有超过两个维度需要置换时,可以使用permute()或者多个transpose()。
例如: a.transpose(1, 3).transpose(2, 3) 与 a.permute(0, 3, 1, 2) 效果相同但建议使用permute(),因为对较大规模的四维张量测试(使用小规模张量时考虑运行速度的意义不大)后发现,permute()比transpose()快50%左右。
3. Tensor的变形
(1) Tensor的一般变形: reshape()
(2) Tensor的临时变形: view()
注:如果只需要使用一次Tensor变形后的结果,可以使用view()而非reshape(),这样可以节省内存。因为reshape()需要开辟新的内存,而view()不需要。
具体可参考: PyTorch:view() 与 reshape() 区别详解
print(a.reshape([3*2, 4*5]).shape) # 输出结果:torch.Size([6, 20])
print(a.view([3*2, 4*5]).shape) # 此时的形状为[6, 20], 数据与Tensor a共享
c = a.view([3*2, 4*5]) # 如果将view()处理后的结果赋值给其他变量,那么仍然需要开辟新的内存存储变形后的结果,此时view()的优势将不复存在
4. Tensor的拼接
在实际应用中,有时需要将张量的某一维度进行拓展。
(1) 一般的拼接:将待拼接的ensor用[]放在一起(可以超过两个)。要求:各待拼接Tensor除了待拼接的维度,其余维度的大小均相同。
注:Tensor的拼接不可以仅使用[],因为拼接后每个元素的类型为Tensor,拼接结果的类型却是List,不符合需求。
d = b.reshape([3, 2, 1, 5]) # d -> shape: [3, 2, 1, 5]
e1 = torch.cat([a, d], dim=2) # e -> shape: [3, 2, 4+1, 5]
# 更简洁高效的写法
e2 = torch.cat([a, b.view([3, 2, 1, 5])], dim=2) # 结果与e1相同,但是更节省内存
(2) 通过拼接实现某一维度的扩展
f = torch.cat([b.view([3, 2, 1, 5])]*4, dim=2) # f -> shape: [3, 2, 4, 5]
5. (本文核心内容)Tensor的拼接与变形
这里介绍“如何保证Tensor在各种拼接和变形后,各维度数据不发生错乱”,这是本文要介绍的核心内容,也是个人经过测试后给出的万无一失的办法
任务:将shape为 [ 3 , 2 , 5 ] [3, 2, 5] [3,2,5]的Tensor b复制扩展为shape为 [ 3 , 2 , 4 , 5 ] [3, 2, 4, 5] [3,2,4,5]的Tensor,然后reshape为 [ 2 ∗ 4 ∗ 3 , 5 ] [2*4*3, 5] [2∗4∗3,5]的Tensor
方法1
c = torch.cat([b.view([3, 2, 1, 5])]*4, dim=2) # c -> shape: [3, 2, 4, 5]
d = c.permute([1, 2, 0, 3]).reshape([-1, 5]) # d -> shape: [2*4*3, 5]
注:第二行代码不能直接reshape,否则得到的Tensor的shape实际上是 [ 3 ∗ 2 ∗ 4 , 5 ] [3*2*4, 5] [3∗2∗4,5], 此时在与其他同(或类)shape的Tensor拼接或运算时,维度与数据会出现错乱。
例如,还有一个shape是 [ 2 ∗ 4 ∗ 3 , 5 ] [2*4*3, 5] [2∗4∗3,5]的Tensor与d相加。如果维度错乱,那么就不是对应元素相加了;而转换为 [ 2 , 4 , 3 , 5 ] [2,4,3,5] [2,4,3,5]的shape又额外增加代码。
方法2
c = torch.cat([b]*4, dim=0) # c-> shape:[4*3, 2, 5]
d = c.transpose(0, 1).reshape([-1, 5]) # d -> shape: [2*4*3, 5]
注:这里的第二行代码同样不能直接reshape,否则得到的Tensor的shape实际上是 [ 4 ∗ 3 ∗ 2 , 5 ] [4*3*2, 5] [4∗3∗2,5]
讲解重点之前,我们先以b为例说明。b的shape是 [ 3 , 2 , 5 ] [3, 2, 5] [3,2,5], 将b具体表示出来(randn随机生成的一个例子)就是
tensor([[[ 0.3543, -0.9587, -0.6313, 1.5067, 1.4628],
[-0.0671, 1.1080, 0.5200, -0.2528, 0.2759]],
[[ 0.1023, -1.7001, 0.0717, 0.2326, 0.1111],
[-0.8022, 0.6989, -0.6247, -1.1926, -0.3376]],
[[ 0.4788, 0.3146, 0.4460, -0.0280, -1.0335],
[-2.4860, 0.7232, 0.5325, 0.4981, -0.0081]]])
由上面的例子可以看出,shape中的“5”是指 [ 5 , ] [5, ] [5,]的向量,shape中的“2”是指有2个 [ 5 , ] [5, ] [5,]的向量,shape中的“3”是指有3个(2个 [ 5 , ] [5, ] [5,]的向量)。
下面我们再回头关注方法2里的第一行代码。使用
torch.cat([b]*n, dim=i)
形式的代码,就是将b的第i个维度的大小(d_i)乘上了n,表示第i个维度有n个 [d_i, ] 的量,因此乘的形式是 nd_i 而非 d_in。此时由reshape()进行变形,才能得到正确的 [ . . . , n , d i , . . . ] [..., n, d_i, ...] [...,n,di,...] 而非 [ . . . , d i , n , . . . ] [..., d_i, n, ...] [...,di,n,...]。( [ . . . , n , d i , . . . ] [..., n, d_i, ...] [...,n,di,...]表示有 …个(n个(d_i个(…的向量))))
如果对你有用,请帮忙点个赞鼓励我创作哦~
以上是关于仿真基本功PyTorch中torch.Tensor变形的相关函数使用方法及说明的主要内容,如果未能解决你的问题,请参考以下文章