如何在pytorch中展平张量?
Posted
技术标签:
【中文标题】如何在pytorch中展平张量?【英文标题】:How do I flatten a tensor in pytorch? 【发布时间】:2019-08-28 00:51:00 【问题描述】:给定一个多维张量,我如何将其展平以使其具有单维?
例如:
>>> t = torch.rand([2, 3, 5])
>>> t.shape
torch.Size([2, 3, 5])
如何将其展平以使其具有形状:
torch.Size([30])
【问题讨论】:
【参考方案1】:TL;DR:torch.flatten()
使用在v0.4.1 中引入并在v1.0rc1 中记录的torch.flatten()
:
>>> t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) >>> torch.flatten(t) tensor([1, 2, 3, 4, 5, 6, 7, 8]) >>> torch.flatten(t, start_dim=1) tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
对于 v0.4.1 及更早版本,请使用 t.reshape(-1)
。
与t.reshape(-1)
:
如果请求的视图在内存中是连续的
这将等同于t.view(-1)
,并且不会复制内存。
否则将等同于t.
contiguous()
.view(-1)
。
其他非选项:
t.view(-1)
won't copy memory, but may not work depending on original size and stride
t.resize(-1)
给出RuntimeError
(见下文)
t.resize(t.numel())
warning about being a low-level method
(见下文讨论)
(注意:pytorch
的reshape()
可能会更改数据,但numpy
's reshape()
won't。)
t.resize(t.numel())
需要一些讨论。 torch.Tensor.resize_
documentation 说:
存储被重新解释为 C 连续,忽略当前步幅(除非目标大小等于当前大小,在这种情况下张量保持不变)
鉴于新的(1, numel())
大小将忽略当前步幅,元素的顺序可能 以与reshape(-1)
不同的顺序出现。但是,“大小”可能表示内存大小,而不是张量的大小。
如果t.resize(-1)
既方便又高效,那就太好了,但torch 1.0.1.post2
,t = torch.rand([2, 3, 5]); t.resize(-1)
提供:
RuntimeError: requested resize to -1 (-1 elements in total), but the given
tensor has a size of 2x2 (4 elements). autograd's resize can only change the
shape of a given tensor, while preserving the number of elements.
我对此here提出了功能要求,但一致认为resize()
是一种低级方法,应优先使用reshape()
。
【讨论】:
您可以指定-1
。 squeeze
操作并不是必需的。
reshape
可以做到。
reshape()
可能会返回原始张量的副本或视图。上面答案中的 Doco 链接。
torch.flatten(var, start_dim=1),start_dim参数很棒。【参考方案2】:
使用torch.reshape
,只能传递一个维度来展平它。如果您不希望对维度进行硬编码,只需指定 -1
即可推断出正确的维度。
>>> x = torch.tensor([[1,2], [3,4]])
>>> x.reshape(-1)
tensor([1, 2, 3, 4])
编辑:
对于您的示例:
【讨论】:
以torch.rand([2, 3, 5])
失败。请参阅我的输出答案。
@TomHale 请查看随附的屏幕截图。不理解投反对票的原因,而不是简单的推动。
我的使用reshape
注意reshape()
总是复制内存。
并非总是如此,正如here所解释的那样【参考方案3】:
flatten()
在C++ PyTorch code 下方使用reshape()
。
使用flatten()
,您可以执行以下操作:
import torch
input = torch.rand(2, 3, 4).cuda()
print(input.shape) # torch.Size([2, 3, 4])
print(input.flatten(start_dim=0, end_dim=1).shape) # torch.Size([6, 4])
如果你想使用reshape
,那么扁平化你会这样做:
print(input.reshape((6,4)).shape) # torch.Size([6, 4])
但通常你会像这样做简单的展平:
print(input.reshape(-1).shape) # torch.Size([24])
print(input.flatten().shape) # torch.Size([24])
注意:
reshape()
比view()
更健壮。它适用于任何张量,而view()
仅适用于张量t
其中t.is_contiguous()==True
。
【讨论】:
【参考方案4】:你可以做一个简单的t.view(-1)
>>>t = torch.rand([2, 3, 5])
>>>t = t.view(-1)
>>>t.shape
torch.Size([30])
【讨论】:
以上是关于如何在pytorch中展平张量?的主要内容,如果未能解决你的问题,请参考以下文章
如何在 pytorch 和 tensorflow 中使用张量核心?