如何使用 torch.stack 函数
Posted
技术标签:
【中文标题】如何使用 torch.stack 函数【英文标题】:How to use torch.stack function 【发布时间】:2019-02-16 17:00:46 【问题描述】:我有一个关于 torch.stack 的问题
我有 2 个张量,a.shape=(2, 3, 4) 和 b.shape=(2, 3)。 如何在没有就地操作的情况下堆叠它们?
【问题讨论】:
【参考方案1】:堆叠需要相同数量的维度。一种方法是解压和堆叠。例如:
a.size() # 2, 3, 4
b.size() # 2, 3
b = torch.unsqueeze(b, dim=2) # 2, 3, 1
# torch.unsqueeze(b, dim=-1) does the same thing
torch.stack([a, b], dim=2) # 2, 3, 5
【讨论】:
您想要的是将torch.cat 与unsqueeze
一起使用。 torch.stack 创建一个 NEW 维度,并且所有提供的张量必须是相同的大小。
这个答案与torch.stack([a, b], dim=2)
不正确,相反,您想使用@drevicko 正确提到的torch.cat([a,b], dim=2)
。 torch.cat
连接给定维度中的序列,而torch.stack
连接新维度中的序列,如下所述:***.com/questions/54307225/…。
这不会运行。相反,您将收到“RuntimeError: stack expects each tensor to be equal size, but got [2, 3, 4] at entry 0 and [2, 3, 1] at entry 1'【参考方案2】:
使用 pytorch 1.2 或 1.4 arjoonn 的答案对我不起作用。
我在 pytorch 1.2 和 1.4 中使用了 torch.cat
而不是 torch.stack
:
>>> import torch
>>> a = torch.randn([2, 3, 4])
>>> b = torch.randn([2, 3])
>>> b = b.unsqueeze(dim=2)
>>> b.shape
torch.Size([2, 3, 1])
>>> torch.cat([a, b], dim=2).shape
torch.Size([2, 3, 5])
如果你想使用torch.stack
,张量的维度必须相同:
>>> a = torch.randn([2, 3, 4])
>>> b = torch.randn([2, 3, 4])
>>> torch.stack([a, b]).shape
torch.Size([2, 2, 3, 4])
这是另一个例子:
>>> t = torch.tensor([1, 1, 2])
>>> stacked = torch.stack([t, t, t], dim=0)
>>> t.shape, stacked.shape, stacked
(torch.Size([3]),
torch.Size([3, 3]),
tensor([[1, 1, 2],
[1, 1, 2],
[1, 1, 2]]))
使用stack
,您可以使用dim
参数指定在哪个维度上堆叠具有相同维度的张量。
【讨论】:
【参考方案3】:假设你有两个张量 a, b 尺寸相等,即 a ( A, B, C) 所以 b (A, B , C) 一个例子
a=torch.randn(2,3,4)
b=torch.randn(2,3,4)
print(a.size()) # 2, 3, 4
print(b.size()) # 2, 3, 4
f=torch.stack([a, b], dim=2) # 2, 3, 2, 4
f
如果它们不是相同的暗淡,它就不会起作用。小心!!
【讨论】:
以上是关于如何使用 torch.stack 函数的主要内容,如果未能解决你的问题,请参考以下文章