PyTorch 中的 .flatten() 和 .view(-1) 有啥区别?
Posted
技术标签:
【中文标题】PyTorch 中的 .flatten() 和 .view(-1) 有啥区别?【英文标题】:What is the difference between .flatten() and .view(-1) in PyTorch?PyTorch 中的 .flatten() 和 .view(-1) 有什么区别? 【发布时间】:2019-12-05 15:04:52 【问题描述】:.flatten()
和 .view(-1)
在 PyTorch 中展平张量。有什么区别?
.flatten()
是否复制张量的数据?
.view(-1)
更快吗?
有没有.flatten()
不起作用的情况?
【问题讨论】:
我认为它们与.flatten()
的默认参数相同,但.flatten()
允许您传递start_dim
和end_dim
以获得更复杂的行为。例如,torch.ones(10, 4, 5, 6).flatten(start_dim=1, end_dim=2)
返回一个形状为 (10, 20, 6)
的张量。
【参考方案1】:
除了@adeelh 的评论,还有一个区别:torch.flatten()
的结果是.reshape()
,而differences between .reshape()
and .view()
是:
[...]
torch.reshape
可能会返回原始张量的副本或视图。您不能指望返回视图或副本。另一个区别是 reshape() 可以对连续张量和非连续张量进行操作,而 view() 只能对连续张量进行操作。另请参阅此处了解连续的含义
对于上下文:
社区请求了一段时间的flatten
功能,在Issue #7743之后,该功能在PR #8578中实现。
你可以看到 flatten here 的实现,在 return
行中可以看到对 .reshape()
的调用。
【讨论】:
【参考方案2】:flatten
只是 convenient alias 的一个常见用例 view
。1
还有其他几个:
Function | Equivalent view logic |
---|---|
flatten() |
view(-1) |
flatten(start, end) |
view(*t.shape[:start], -1, *t.shape[end+1:]) |
squeeze() |
view(*[s for s in t.shape if s != 1]) |
unsqueeze(i) |
view(*t.shape[:i-1], 1, *t.shape[i:]) |
请注意,flatten
允许您使用 start_dim
和 end_dim
参数来展平特定的连续维度子集。
-
实际上表面上等同于
reshape
。
【讨论】:
以上是关于PyTorch 中的 .flatten() 和 .view(-1) 有啥区别?的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch tensor(1.3):Reshape Operations--flatten
PyTorch torch.flatten()与nn.Flatten()的区别
PyTorch torch.flatten()与nn.Flatten()的区别