在 PyTorch 中,是啥让张量具有非连续内存?
Posted
技术标签:
【中文标题】在 PyTorch 中,是啥让张量具有非连续内存?【英文标题】:In PyTorch, what makes a tensor have non-contiguous memory?在 PyTorch 中,是什么让张量具有非连续内存? 【发布时间】:2019-06-03 08:31:38 【问题描述】:根据this SO 和this PyTorch discussion,PyTorch 的view
函数仅适用于连续内存,而reshape
则不行。在第二个链接中,作者甚至声称:
[
view
] 将在非连续张量上引发错误。
但是张量何时具有非连续记忆?
【问题讨论】:
【参考方案1】:This 是一个非常好的答案,它在 NumPy 的上下文中解释了该主题。 PyTorch 的工作原理基本相同。它的文档通常不会提到函数输出是否(非)连续,但这是可以根据操作类型(具有一些经验和对实现的理解)猜测的。根据经验,大多数操作在构造新张量时会保持连续性。如果操作在原地数组上运行并更改其跨步,您可能会看到不连续的输出。下面举几个例子
import torch
t = torch.randn(10, 10)
def check(ten):
print(ten.is_contiguous())
check(t) # True
# flip sets the stride to negative, but element j is still adjacent to
# element i, so it is contiguous
check(torch.flip(t, (0,))) # True
# if we take every 2nd element, adjacent elements in the resulting array
# are not adjacent in the input array
check(t[::2]) # False
# if we transpose, we lose contiguity, as in case of NumPy
check(t.transpose(0, 1)) # False
# if we transpose twice, we first lose and then regain contiguity
check(t.transpose(0, 1).transpose(0, 1)) # True
一般来说,如果你有不连续的张量t
,你可以通过调用t = t.contiguous()
使其连续。如果t
是连续的,则调用t.contiguous()
本质上是无操作的,因此您可以这样做而不会冒很大的性能损失。
【讨论】:
【参考方案2】:我认为您的标题contiguous memory
有点误导。据我了解,PyTorch 中的contiguous
表示张量中的相邻元素是否在内存中实际上彼此相邻。举个简单的例子:
x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # x is contiguous
y = torch.transpose(0, 1) # y is non-contiguous
根据documentation of tranpose()
:
返回一个张量,它是输入的转置版本。给定尺寸 dim0 和 dim1 交换。
生成的张量与输入张量共享其底层存储,因此更改一个张量的内容会改变另一个张量的内容。
所以上面例子中的x
和y
共享同一个内存空间。但是如果你用is_contiguous()
检查它们的连续性,你会发现x
是连续的,而y
不是。现在你会发现contiguity
并不是指contiguous memory
。
由于x
是连续的,x[0][0]
和x[0][1]
在内存中彼此相邻。但是y[0][0]
和y[0][1]
不是。这就是contiguous
的意思。
【讨论】:
以上是关于在 PyTorch 中,是啥让张量具有非连续内存?的主要内容,如果未能解决你的问题,请参考以下文章