-1 在 pytorch 视图中是啥意思?

Posted

技术标签:

【中文标题】-1 在 pytorch 视图中是啥意思?【英文标题】:What does -1 mean in pytorch view?-1 在 pytorch 视图中是什么意思? 【发布时间】:2018-11-20 09:13:39 【问题描述】:

正如问题所说,-1 在 pytorch view 中做了什么?

>>> a = torch.arange(1, 17)
>>> a
tensor([  1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
         11.,  12.,  13.,  14.,  15.,  16.])

>>> a.view(1,-1)
tensor([[  1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
          11.,  12.,  13.,  14.,  15.,  16.]])

>>> a.view(-1,1)
tensor([[  1.],
        [  2.],
        [  3.],
        [  4.],
        [  5.],
        [  6.],
        [  7.],
        [  8.],
        [  9.],
        [ 10.],
        [ 11.],
        [ 12.],
        [ 13.],
        [ 14.],
        [ 15.],
        [ 16.]])

它 (-1) 会生成额外的维度吗? 它的行为是否与 numpy reshape -1 相同?

【问题讨论】:

据我所知(我不是专业人士..),给定的维度 -1 将适应其他维度。所以a.view(-1,1) 将产生一个维度为17x1 的向量,因为有17 个值 - 所以v.view(1,-1) 将产生一个1x17 向量...... 【参考方案1】:

是的,它的行为类似于numpy.reshape() 中的-1,即会推断此维度的实际值,以便视图中的元素数与原始元素数相匹配。

例如:

import torch

x = torch.arange(6)

print(x.view(3, -1))      # inferred size will be 2 as 6 / 3 = 2
# tensor([[ 0.,  1.],
#         [ 2.,  3.],
#         [ 4.,  5.]])

print(x.view(-1, 6))      # inferred size will be 1 as 6 / 6 = 1
# tensor([[ 0.,  1.,  2.,  3.,  4.,  5.]])

print(x.view(1, -1, 2))   # inferred size will be 3 as 6 / (1 * 2) = 3
# tensor([[[ 0.,  1.],
#          [ 2.,  3.],
#          [ 4.,  5.]]])

# print(x.view(-1, 5))    # throw error as there's no int N so that 5 * N = 6
# RuntimeError: invalid argument 2: size '[-1 x 5]' is invalid for input with 6 elements

print(x.view(-1, -1, 3))  # throw error as only one dimension can be inferred
# RuntimeError: invalid argument 1: only one dimension can be inferred

【讨论】:

如果我们自己有 -1 怎么办?例如我面前有这个:correct[:k].view(-1)。在这种特殊情况下,这有什么作用? @CharlieParker:这将使张量变平(类似于torch.flatten(correct)),即返回一个包含所有元素的单一维度的张量。例如,在我的答案中的命令之后运行 x.view(-1) 将返回 tensor([0., 1., 2., 3., 4., 5.]),即一个尺寸为 6 的单维张量。【参考方案2】:

我喜欢本杰明给出的答案https://***.com/a/50793899/1601580

是的,它的行为类似于 numpy.reshape() 中的 -1,即将推断此维度的实际值,以便视图中的元素数与原始元素数匹配。

但我认为在使用单个 -1 即 tensor.view(-1) 调用它时,可能对您(或至少对我而言不是)不直观的奇怪案例边缘情况。 我的猜测是它的工作方式与往常完全相同,只是因为您提供一个数字来查看它假设您想要一个维度。如果您有 tensor.view(-1, Dnew) 它会产生一个张量两个维度/索引,但会根据张量的原始维度确保第一个维度的大小正确。假设你有(D1, D2),你有Dnew=D1*D2,那么新维度将为1。

对于您可以运行的带有代码的真实示例:

import torch

x = torch.randn(1, 5)
x = x.view(-1)
print(x.size())

x = torch.randn(2, 4)
x = x.view(-1, 8)
print(x.size())

x = torch.randn(2, 4)
x = x.view(-1)
print(x.size())

x = torch.randn(2, 4, 3)
x = x.view(-1)
print(x.size())

输出:

torch.Size([5])
torch.Size([1, 8])
torch.Size([8])
torch.Size([24])

历史/背景

我觉得是一个很好的例子(common case early on in pytorch beforeflatten 层是official added这个通用代码):

class Flatten(nn.Module):
    def forward(self, input):
        # input.size(0) usually denotes the batch size so we want to keep that
        return input.view(input.size(0), -1)

用于顺序。在这个视图中,x.view(-1) 是一个奇怪的扁平层,但缺少挤压(即添加维度 1)。添加或删除此压缩通常对于代码实际运行很重要。

【讨论】:

【参考方案3】:

我猜这类似于np.reshape:

新形状应与原始形状兼容。如果是整数,则结果将是该长度的一维数组。一个形状维度可以是-1。在这种情况下,值是从数组的长度和剩余维度推断出来的

如果您有a = torch.arange(1, 18),您可以通过多种方式查看它,例如a.view(-1,6)a.view(-1,9)a.view(3,-1) 等。

【讨论】:

如果我们自己有 -1 怎么办?例如我面前有这个:correct[:k].view(-1)。在这种特殊情况下会做什么?【参考方案4】:

From the PyTorch documentation:

>>> x = torch.randn(4, 4)
>>> x.size()
torch.Size([4, 4])
>>> y = x.view(16)
>>> y.size()
torch.Size([16])
>>> z = x.view(-1, 8)  # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])

【讨论】:

如果我们自己有 -1 怎么办?例如我面前有这个:correct[:k].view(-1)。在这种特殊情况下,这有什么作用?【参考方案5】:

-1 推断为 2,例如,如果你有

>>> a = torch.rand(4,4)
>>> a.size()
torch.size([4,4])
>>> y = x.view(16)
>>> y.size()
torch.size([16])
>>> z = x.view(-1,8) # -1 is generally inferred as 2  i.e (2,8)
>>> z.size()
torch.size([2,8])

【讨论】:

如果我们自己有 -1 怎么办?例如我面前有这个:correct[:k].view(-1)。在这种特殊情况下会做什么?【参考方案6】:

-1 是 PyTorch 的别名,用于“在其他维度都已指定的情况下推断此维度”(即原始产品与新产品的商)。这是取自numpy.reshape() 的约定。

因此示例中的t.view(1,17) 将等效于t.view(1,-1)t.view(-1,17)

【讨论】:

以上是关于-1 在 pytorch 视图中是啥意思?的主要内容,如果未能解决你的问题,请参考以下文章

“||”是啥意思在 var 语句中是啥意思? [复制]

“内容”是啥意思:在招摇/openapi“响应”中是啥意思:

问号和点运算符是啥意思?在 C# 6.0 中是啥意思?

*/ 在 XPath 中是啥意思?

-> 在 C++ 中是啥意思? [复制]

“**”在python中是啥意思? [复制]