终于有人搞懂了详解 torch.unsqueeze() 和 torch.squeeze()

Posted ZSYL

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了终于有人搞懂了详解 torch.unsqueeze() 和 torch.squeeze()相关的知识,希望对你有一定的参考价值。

1. 入门测试

  • torch.squeeze(input, dim = None, out = None): 返回一个tensor
    • 当dim不设值时,去掉输入的tensor的所有维度为1的维度;
    • 当dim为某一整数(0<=dim<input.dim())时,判断dim维的维度是否为1,若是则去掉,否则不变。
    • 另外,当input是一维的时候,squeeze不变
>>> x = torch.zeros(1,1,2,1,3)
>>> x.dim()
5
>>> torch.squeeze(x).size() # 去掉dim=1的维度
torch.Size([2, 3])
>>> torch.squeeze(x,0).size()  # dim=0表示第一维,且第一维的维度为1,所以去掉
torch.Size([1, 2, 1, 3])
>>> torch.squeeze(x,3).size()
torch.Size([1, 1, 2, 3])
>>> torch.squeeze(x,2).size()  # dim=2,第三维的维度为2!=1,所以不变
torch.Size([1, 1, 2, 1, 3])
  • torch.unqueeze(input, dim, out=None): 和squeeze作用相反,unsqueeze()在dim维插入一个维度为1的维,例如原来x是n×m维的,torch.unqueeze(x,0)这返回1×n×m的tensor
>>> x = torch.tensor([1,2,3])  # dim=1,即(3)
>>> torch.unsqueeze(x, 1)  # 变为(3,1)的矩阵
tensor([[ 1],
        [ 2],
        [ 3]])
  1. squeeze:压缩(降维)
  2. unqueeze:解压缩(升维)

2. 深入研究

2.1 torch.unsqueeze 详解

torch.unsqueeze(input, dim, out=None)
  • 作用:扩展维度

返回一个新的张量,对输入的既定位置插入维度 1

  • 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

如果dim为负,则将会被转化dim+input.dim()+1

import torch

x = torch.tensor([1, 2, 3])
print(x)
print(x.size())
print(torch.unsqueeze(x, -1))  # dim=-1作用效果等同dim=1
print(torch.unsqueeze(x, -1).size())
>>>
tensor([1, 2, 3])
torch.Size([3])
------------
tensor([[1],
        [2],
        [3]])
torch.Size([3, 1])
import torch

x = torch.tensor([1, 2, 3])
print(x)
print(x.size())
print(torch.unsqueeze(x, -2))  # dim=-2:转化为(1, input_dim)
print(torch.unsqueeze(x, -2).size())
>>>
tensor([1, 2, 3])
torch.Size([3])
------------
tensor([[1, 2, 3]])
torch.Size([1, 3])

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got -3)

  • 参数:
    • tensor (Tensor) – 输入张量
    • dim (int) – 插入维度的索引
    • out (Tensor, optional) – 结果张量
import torch

x = torch.Tensor([1, 2, 3, 4])  # torch.Tensor是默认的tensor类型(torch.FlaotTensor)的简称。

print('-' * 50)
print(x)  # tensor([1., 2., 3., 4.])
print(x.size())  # torch.Size([4])
print(x.dim())  # 1
print(x.numpy())  # [1. 2. 3. 4.]

print('-' * 50)
print(torch.unsqueeze(x, 0))  # tensor([[1., 2., 3., 4.]])
print(torch.unsqueeze(x, 0).size())  # torch.Size([1, 4])
print(torch.unsqueeze(x, 0).dim())  # 2
print(torch.unsqueeze(x, 0).numpy())  # [[1. 2. 3. 4.]]

print('-' * 50)
print(torch.unsqueeze(x, 1))
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])
print(torch.unsqueeze(x, 1).size())  # torch.Size([4, 1])
print(torch.unsqueeze(x, 1).dim())  # 2

print('-' * 50)
print(torch.unsqueeze(x, -1))
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])
print(torch.unsqueeze(x, -1).size())  # torch.Size([4, 1])
print(torch.unsqueeze(x, -1).dim())  # 2

print('-' * 50)
print(torch.unsqueeze(x, -2))  # tensor([[1., 2., 3., 4.]])
print(torch.unsqueeze(x, -2).size())  # torch.Size([1, 4])
print(torch.unsqueeze(x, -2).dim())  # 2

# 边界测试
# 说明:A dim value within the range [-input.dim() - 1, input.dim() + 1) (左闭右开)can be used.
# print('-' * 50)
# print(torch.unsqueeze(x, -3))
# IndexError: Dimension out of range (expected to be in range of [-2, 1], but got -3)

# print('-' * 50)
# print(torch.unsqueeze(x, 2))
# IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

# 为何取值范围要如此设计呢?
# 原因:方便操作
# 0(-2)-行扩展
# 1(-1)-列扩展
# 正向:我们在0,1位置上扩展
# 逆向:我们在-2,-1位置上扩展
# 维度扩展:1维->2维,2维->3维,...,n维->n+1维
# 维度降低:n维->n-1维,n-1维->n-2维,...,2维->1维

# 以 1维->2维 为例,

# 从【正向】的角度思考:

# torch.Size([4])
# 最初的 tensor([1., 2., 3., 4.]) 是 1维,我们想让它扩展成 2维,那么,可以有两种扩展方式:

# 一种是:扩展成 1行4列 ,即 tensor([[1., 2., 3., 4.]])
# 针对第一种,扩展成 [1, 4]的形式,那么,在 dim=0 的位置上添加 1

# 另一种是:扩展成 4行1列,即
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])
# 针对第二种,扩展成 [4, 1]的形式,那么,在dim=1的位置上添加 1

# 从【逆向】的角度思考:
# 原则:一般情况下, "-1" 是代表的是【最后一个元素】
# 在上述的原则下,
# 扩展成[1, 4]的形式,就变成了,在 dim=-2 的的位置上添加 1
# 扩展成[4, 1]的形式,就变成了,在 dim=-1 的的位置上添加 1

dim值对应增加维度的方式:

2.2 unsqueeze_和 unsqueeze 的区别

unsqueeze_unsqueeze 实现一样的功能,区别在于 unsqueeze_in_place 操作,即 unsqueeze 不会对使用 unsqueeze 的 tensor 进行改变,想要获取 unsqueeze 后的值必须赋予个新值, unsqueeze_ 则会对自己改变

print("-" * 50)
a = torch.Tensor([1, 2, 3, 4])
print(a)
# tensor([1., 2., 3., 4.])

b = torch.unsqueeze(a, 1)
print(b)
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])

print(a)
# tensor([1., 2., 3., 4.])


print("-" * 50)
a = torch.Tensor([1, 2, 3, 4])
print(a)
# tensor([1., 2., 3., 4.])

print(a.unsqueeze_(1))
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])

print(a)
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])

2.3 torch.squeeze 详解

torch.squeeze(input, dim=None, out=None)
  • 作用:降维

1. 将输入张量形状中的1 去除并返回。

如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

2. 当给定dim时,那么挤压操作只在给定维度上。

例如,输入形状为: (A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。

  • 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
  • 参数:
    • input (Tensor) – 输入张量
    • dim (int, optional) – 如果给定,则input只会在给定维度挤压
    • out (Tensor, optional) – 输出张量

为何只去掉 1 呢?

多维张量本质上就是一个变换,如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度

print("*" * 50)

m = torch.zeros(2, 1, 2, 1, 2)
print(m.size())  # torch.Size([2, 1, 2, 1, 2])

n = torch.squeeze(m)
print(n.size())  # torch.Size([2, 2, 2])

n = torch.squeeze(m, 0)  # 当给定dim时,那么挤压操作只在给定维度上
print(n.size())  # torch.Size([2, 1, 2, 1, 2])

n = torch.squeeze(m, 1)
print(n.size())  # torch.Size([2, 2, 1, 2])

n = torch.squeeze(m, 2)
print(n.size())  # torch.Size([2, 1, 2, 1, 2])

n = torch.squeeze(m, 3)
print(n.size())  # torch.Size([2, 1, 2, 2])

print("@" * 50)
p = torch.zeros(2, 1, 1)
print(p)
# tensor([[[0.]],
#         [[0.]]])
print(p.numpy())
# [[[0.]]
#  [[0.]]]

print(p.size())
# torch.Size([2, 1, 1])

q = torch.squeeze(p)
print(q)
# tensor([0., 0.])

print(q.numpy())
# [0. 0.]

print(q.size())
# torch.Size([2])

print(torch.zeros(3, 2).numpy())
# [[0. 0.]
#  [0. 0.]
#  [0. 0.]]

参考Link


加油!

感谢!

努力!

以上是关于终于有人搞懂了详解 torch.unsqueeze() 和 torch.squeeze()的主要内容,如果未能解决你的问题,请参考以下文章

今天终于把爬虫的Ajax请求搞懂了

今天终于把爬虫的Ajax请求搞懂了

终于搞懂了vue 的 render 函数 -_-|||

torch.unsqueeze

多次尝试学习,终于搞懂了微服务架构

终于搞懂了shell bash cmd...