PyTorch dimension transformations

Posted wx5cbd4315aefc1

Foreword: This article was compiled by the editors of 小常识网 (cha138.com) and introduces dimension transformations in PyTorch. We hope you find it useful.


import torch as t
a=t.rand(4,1,28,28)
a.shape
torch.Size([4, 1, 28, 28])
a.view(4,28*28)
tensor([[0.7170, 0.7973, 0.8322,  ..., 0.2318, 0.9531, 0.6618],
[0.7864, 0.9424, 0.7775, ..., 0.7293, 0.7722, 0.3481],
[0.4799, 0.8308, 0.0384, ..., 0.5429, 0.3318, 0.3176],
[0.3548, 0.2738, 0.1126, ..., 0.8103, 0.5105, 0.5830]])
a.view(4,28*28).shape
torch.Size([4, 784])
a.view(4*28,28).shape
torch.Size([112, 28])
a.view(4*1,28,28).shape
torch.Size([4, 28, 28])
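
view() only reinterprets the shape, so the product of the new sizes must equal a.numel(). One size may be given as -1, and PyTorch then infers it. A minimal sketch on random data as above (783 is just an arbitrary wrong size chosen to trigger the error):

```python
import torch

a = torch.rand(4, 1, 28, 28)          # 4 * 1 * 28 * 28 = 3136 elements

# The new sizes must multiply to a.numel(); -1 is inferred as 3136 / 4 = 784
flat = a.view(4, -1)
print(flat.shape)                     # torch.Size([4, 784])

# At most one size may be -1
imgs = a.view(-1, 28, 28)
print(imgs.shape)                     # torch.Size([4, 28, 28])

# 4 * 783 != 3136, so this view is rejected
try:
    a.view(4, 783)
except RuntimeError as err:
    print("view failed:", err)
```
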
b=a.view(4,784)
b.view(4,28,28,1)
tensor([[[[0.7170],
[0.7973],
[0.8322],
...,
[0.7292],
[0.7369],
[0.0919]],

[[0.2956],
[0.8355],
[0.8593],
...,
[0.3294],
[0.9853],
[0.3265]],

[[0.1045],
[0.9306],
[0.4242],
...,
[0.4399],
[0.5458],
[0.6823]],

...,

[[0.0085],
[0.6165],
[0.5018],
...,
[0.2905],
[0.7364],
[0.2715]],

[[0.0685],
[0.6373],
[0.7948],
...,
[0.1856],
[0.7264],
[0.2514]],

[[0.7125],
[0.1486],
[0.4361],
...,
[0.2318],
[0.9531],
[0.6618]]],


[[[0.7864],
[0.9424],
[0.7775],
...,
[0.2481],
[0.6750],
[0.0833]],

[[0.4633],
[0.0623],
[0.2279],
...,
[0.6857],
[0.6348],
[0.2083]],

[[0.7915],
[0.0695],
[0.2783],
...,
[0.1555],
[0.5421],
[0.1337]],

...,

[[0.0955],
[0.4038],
[0.6088],
...,
[0.3266],
[0.4750],
[0.5062]],

[[0.5249],
[0.0367],
[0.4000],
...,
[0.3639],
[0.4786],
[0.0517]],

[[0.1864],
[0.3414],
[0.5211],
...,
[0.7293],
[0.7722],
[0.3481]]],


[[[0.4799],
[0.8308],
[0.0384],
...,
[0.7505],
[0.6558],
[0.8692]],

[[0.8836],
[0.7475],
[0.3443],
...,
[0.1412],
[0.2885],
[0.0483]],

[[0.7127],
[0.2985],
[0.2680],
...,
[0.1241],
[0.6580],
[0.3919]],

...,

[[0.0063],
[0.3349],
[0.7492],
...,
[0.5369],
[0.4494],
[0.8487]],

[[0.2440],
[0.9463],
[0.5812],
...,
[0.9820],
[0.1489],
[0.1279]],

[[0.9872],
[0.7186],
[0.7177],
...,
[0.5429],
[0.3318],
[0.3176]]],


[[[0.3548],
[0.2738],
[0.1126],
...,
[0.9997],
[0.8620],
[0.0051]],

[[0.3392],
[0.4705],
[0.5175],
...,
[0.4567],
[0.5824],
[0.0641]],

[[0.4711],
[0.5184],
[0.5050],
...,
[0.4252],
[0.6838],
[0.0144]],

...,

[[0.3049],
[0.8823],
[0.6849],
...,
[0.4563],
[0.6089],
[0.1411]],

[[0.5463],
[0.3497],
[0.5929],
...,
[0.8492],
[0.6190],
[0.5833]],

[[0.1441],
[0.9260],
[0.8446],
...,
[0.8103],
[0.5105],
[0.5830]]]])
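
The long printout shows that view() does not reorder anything: the first block of b.view(4,28,28,1) starts with 0.7170 and ends with 0.6618, exactly like row 0 of a.view(4,28*28). The sketch below checks this on fresh random data (so the values differ from the printout) and also shows that the views share one storage:

```python
import torch

a = torch.rand(4, 1, 28, 28)
b = a.view(4, 784)
c = b.view(4, 28, 28, 1)

# view() never moves data: all three tensors point at the same storage
print(a.data_ptr() == b.data_ptr() == c.data_ptr())   # True

# Elements are taken in row-major order, so sample 0 of c holds exactly row 0 of b
print(torch.equal(c[0].flatten(), b[0]))              # True

# Viewing back restores the original tensor exactly
print(torch.equal(c.view(4, 1, 28, 28), a))           # True

# Since storage is shared, an in-place write through c is visible through a
c[0, 0, 0, 0] = -1.0
print(a[0, 0, 0, 0].item())                           # -1.0
```
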
a.shape
torch.Size([4, 1, 28, 28])
a.unsqueeze(0).shape
torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(-1).shape
torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(4).shape
torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(-4).shape
torch.Size([4, 1, 1, 28, 28])
a.unsqueeze(-5).shape
torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(5).shape
---------------------------------------------------------------------------

IndexError Traceback (most recent call last)

<ipython-input-14-b54eab361a50> in <module>
----> 1 a.unsqueeze(5).shape


IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
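
unsqueeze(d) adds a size-1 dimension so that it ends up at index d of the result. For a 4-D tensor, d may range over [-5, 4], as the IndexError above states, and a negative d counts from the end of the 5-D result, i.e. it names the same slot as d + a.dim() + 1. A small sketch that checks this mapping for every allowed d:

```python
import torch

a = torch.rand(4, 1, 28, 28)   # a.dim() == 4

# Valid d for unsqueeze: [-(a.dim() + 1), a.dim()], i.e. [-5, 4] here.
# The new size-1 dimension lands at index d of the 5-D result;
# a negative d counts from the end, equivalent to pos = d + a.dim() + 1.
for d in range(-(a.dim() + 1), a.dim() + 1):
    pos = d if d >= 0 else d + a.dim() + 1
    out = a.unsqueeze(d)
    assert out.shape == a.unsqueeze(pos).shape
    assert out.shape[pos] == 1
    print(d, pos, tuple(out.shape))
```
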
a=t.tensor([1.2,2.3])
a.unsqueeze(-1)
tensor([[1.2000],
[2.3000]])
a.unsqueeze(0)
tensor([[1.2000, 2.3000]])
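
On a 1-D tensor, unsqueeze(-1) produces a column vector and unsqueeze(0) a row vector. Indexing with None inserts a size-1 dimension in the same way. A short sketch (the broadcast at the end is only there to show why the orientation matters):

```python
import torch

a = torch.tensor([1.2, 2.3])

col = a.unsqueeze(-1)     # shape [2, 1]: a column vector
row = a.unsqueeze(0)      # shape [1, 2]: a row vector

# Indexing with None inserts a size-1 dimension at the same position
print(torch.equal(col, a[:, None]), torch.equal(row, a[None, :]))   # True True

# The two orientations broadcast against each other into a 2 x 2 grid
print((col + row).shape)  # torch.Size([2, 2])
```
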
b=t.rand(32)
f=t.rand(4,32,14,14)
b=b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
b.shape
torch.Size([1, 32, 1, 1])
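
The transcript creates f with shape [4, 32, 14, 14] but never uses it. Presumably the point of turning a 32-element vector into [1, 32, 1, 1] is to let it broadcast against such a feature map, for example as a per-channel bias. A sketch under that assumption:

```python
import torch

b = torch.rand(32)                 # one value per channel (assumed to be a bias)
f = torch.rand(4, 32, 14, 14)      # batch of feature maps: [N, C, H, W]

b4 = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)   # [1, 32, 1, 1]
out = f + b4                                     # size-1 dims broadcast to 4, 14, 14
print(out.shape)                                 # torch.Size([4, 32, 14, 14])

# Every spatial position of channel 5 in sample 2 received the same b[5]
print(torch.equal(out[2, 5], f[2, 5] + b[5]))    # True

# The raw 1-D b would be aligned with the last axis (W = 14) instead, and fail
try:
    f + b
except RuntimeError as err:
    print("broadcast failed:", err)
```
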
b.squeeze().shape
torch.Size([32])
b.squeeze(0).shape
torch.Size([32, 1, 1])
b.squeeze(-1).shape
torch.Size([1, 32, 1])
b.squeeze(1).shape
torch.Size([1, 32, 1, 1])
b.squeeze(-4).shape
torch.Size([32, 1, 1])
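
squeeze() removes only dimensions of size 1. With no argument it drops all of them, and squeeze(d) is silently a no-op when dimension d is larger than 1, which is why b.squeeze(1) above returns [1, 32, 1, 1] unchanged. A sketch that checks this rule for every valid d:

```python
import torch

b = torch.rand(32).unsqueeze(1).unsqueeze(2).unsqueeze(0)   # [1, 32, 1, 1]

# With no argument, every size-1 dimension is removed
print(b.squeeze().shape)   # torch.Size([32])

# squeeze(d) removes dimension d only if it has size 1; otherwise nothing changes
for d in range(-b.dim(), b.dim()):
    idx = d % b.dim()                       # normalise negative d to 0..3
    expected = [s for i, s in enumerate(b.shape) if not (i == idx and s == 1)]
    assert list(b.squeeze(d).shape) == expected
    print(d, tuple(b.squeeze(d).shape))
```
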
a=t.rand(4,32,14,14)
b.shape
torch.Size([1, 32, 1, 1])
b.expand(4,32,14,14).shape
torch.Size([4, 32, 14, 14])
b.expand(-1,32,-1,-1).shape
torch.Size([1, 32, 1, 1])
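b.expand(-1,32,-1,-4).shape
torch.Size([1, 32, 1, -4])
# Note: -1 is the only meaningful negative size for expand(); it keeps that dimension unchanged. The [1, 32, 1, -4] shape above was printed by an older PyTorch release that did not validate negative sizes. It is not a usable shape, and newer releases should reject it with a RuntimeError.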
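
expand() broadcasts size-1 dimensions without copying anything: the expanded dimensions get stride 0, so every index along them reads the same memory. A minimal sketch (the size 64 in the last call is just an example of an invalid target):

```python
import torch

b = torch.rand(32).unsqueeze(1).unsqueeze(2).unsqueeze(0)   # [1, 32, 1, 1]

e = b.expand(4, 32, 14, 14)
print(e.shape)                        # torch.Size([4, 32, 14, 14])

# The broadcast dimensions get stride 0 and no new memory is allocated
print(e.stride())                     # (0, 1, 0, 0)
print(e.data_ptr() == b.data_ptr())   # True

# Only size-1 dimensions can be expanded; -1 keeps a size unchanged
print(b.expand(-1, 32, -1, -1).shape)   # torch.Size([1, 32, 1, 1])
try:
    b.expand(4, 64, 14, 14)             # dim 1 has size 32, not 1
except RuntimeError as err:
    print("expand failed:", err)
```
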
b.shape
torch.Size([1, 32, 1, 1])
b.repeat(4,32,1,1).shape
torch.Size([4, 1024, 1, 1])
b.repeat(4,1,1,1).shape
torch.Size([4, 32, 1, 1])
b.repeat(4,1,32,32).shape
torch.Size([4, 32, 32, 32])
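
repeat() differs from expand() in two ways: its arguments are per-dimension multipliers rather than target sizes (hence [4, 1024, 1, 1] for repeat(4,32,1,1) above), and it physically copies the data. A sketch comparing the two:

```python
import torch

b = torch.rand(32).unsqueeze(1).unsqueeze(2).unsqueeze(0)   # [1, 32, 1, 1]

# Multipliers: [1*4, 32*1, 1*14, 1*14] -> [4, 32, 14, 14]
r = b.repeat(4, 1, 14, 14)
e = b.expand(4, 32, 14, 14)

print(torch.equal(r, e))                     # True: same values
print(r.data_ptr() == b.data_ptr())          # False: repeat allocated new memory
print(r.is_contiguous(), e.is_contiguous())  # True False
```
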
a.shape
torch.Size([4, 32, 14, 14])
a=t.rand(4,3,32,32)
a.shape
torch.Size([4, 3, 32, 32])
a1=a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)
---------------------------------------------------------------------------

RuntimeError Traceback (most recent call last)

<ipython-input-37-b44bf620887c> in <module>
----> 1 a1=a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
a1=a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
a2=a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
a1.shape,a2.shape
(torch.Size([4, 3, 32, 32]), torch.Size([4, 3, 32, 32]))
t.all(t.eq(a,a1))
tensor(False)
t.all(t.eq(a,a2))
tensor(True)
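
transpose() only swaps strides, so the result is no longer contiguous in memory and view() refuses to flatten it. contiguous() makes a compact copy first, and reshape() performs that copy automatically when it is needed. a1 differs from a because the flattened data is in [N, W, H, C] order but was viewed back as [N, C, H, W]; a2 views it as [N, W, H, C] and transposes back. A sketch:

```python
import torch

a = torch.rand(4, 3, 32, 32)          # [N, C, H, W]
t13 = a.transpose(1, 3)               # [N, W, H, C], a view with swapped strides

print(t13.is_contiguous())            # False, which is why .view() failed above
print(a.stride(), t13.stride())       # (3072, 1024, 32, 1) (3072, 1, 32, 1024)

# reshape() copies only when needed and matches contiguous().view()
flat1 = t13.contiguous().view(4, 3 * 32 * 32)
flat2 = t13.reshape(4, 3 * 32 * 32)
print(torch.equal(flat1, flat2))      # True

# The flat data is in [N, W, H, C] order: view it as [4, 32, 32, 3], then transpose back
restored = flat2.view(4, 32, 32, 3).transpose(1, 3)
print(torch.equal(restored, a))       # True
```
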
a=t.rand(4,3,28,28)
a.transpose(1,3).shape
torch.Size([4, 28, 28, 3])
b=t.rand(4,3,28,32)
b.transpose(1,3).shape
torch.Size([4, 32, 28, 3])
b.transpose(1,3).transpose(1,2).shape
torch.Size([4, 28, 32, 3])
b.permute(0,2,3,1).shape
torch.Size([4, 28, 32, 3])
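
permute() takes, for each output position, the index of the input dimension to place there, so permute(0,2,3,1) turns [N, C, H, W] into [N, H, W, C] in one call and matches the two chained transpose() calls. A short sketch, including the inverse permutation:

```python
import torch

b = torch.rand(4, 3, 28, 32)          # [N, C, H, W]

nhwc = b.permute(0, 2, 3, 1)          # [N, H, W, C]
chained = b.transpose(1, 3).transpose(1, 2)
print(nhwc.shape, torch.equal(nhwc, chained))   # torch.Size([4, 28, 32, 3]) True

# The inverse permutation restores the original layout
back = nhwc.permute(0, 3, 1, 2)       # [N, C, H, W]
print(torch.equal(back, b))           # True

# Like transpose, permute only rearranges strides, so the result is not contiguous
print(nhwc.is_contiguous())           # False
```
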

