如何使用 torch.stack 函数

Posted

技术标签:

【中文标题】如何使用 torch.stack 函数【英文标题】:How to use torch.stack function 【发布时间】:2019-02-16 17:00:46 【问题描述】:

我有一个关于 torch.stack 的问题

我有 2 个张量,a.shape=(2, 3, 4) 和 b.shape=(2, 3)。 如何在没有就地操作的情况下堆叠它们

【问题讨论】:

【参考方案1】:

堆叠需要相同数量的维度。一种方法是解压和堆叠。例如:

a.size()  # 2, 3, 4
b.size()  # 2, 3
b = torch.unsqueeze(b, dim=2)  # 2, 3, 1
# torch.unsqueeze(b, dim=-1) does the same thing

torch.stack([a, b], dim=2)  # 2, 3, 5

【讨论】:

您想要的是将torch.cat 与unsqueeze 一起使用。 torch.stack 创建一个 NEW 维度,并且所有提供的张量必须是相同的大小。 这个答案与torch.stack([a, b], dim=2) 不正确,相反,您想使用@drevicko 正确提到的torch.cat([a,b], dim=2)torch.cat 连接给定维度中的序列,而torch.stack 连接新维度中的序列,如下所述:***.com/questions/54307225/…。 这不会运行。相反,您将收到“RuntimeError: stack expects each tensor to be equal size, but got [2, 3, 4] at entry 0 and [2, 3, 1] at entry 1'【参考方案2】:

使用 pytorch 1.2 或 1.4 arjoonn 的答案对我不起作用。

我在 pytorch 1.2 和 1.4 中使用了 torch.cat 而不是 torch.stack

>>> import torch
>>> a = torch.randn([2, 3, 4])
>>> b = torch.randn([2, 3])
>>> b = b.unsqueeze(dim=2)
>>> b.shape
torch.Size([2, 3, 1])
>>> torch.cat([a, b], dim=2).shape
torch.Size([2, 3, 5])

如果你想使用torch.stack,张量的维度必须相同:

>>> a = torch.randn([2, 3, 4])
>>> b = torch.randn([2, 3, 4])
>>> torch.stack([a, b]).shape
torch.Size([2, 2, 3, 4])

这是另一个例子:

>>> t = torch.tensor([1, 1, 2])
>>> stacked = torch.stack([t, t, t], dim=0)
>>> t.shape, stacked.shape, stacked

(torch.Size([3]),
 torch.Size([3, 3]),
 tensor([[1, 1, 2],
         [1, 1, 2],
         [1, 1, 2]]))

使用stack,您可以使用dim 参数指定在哪个维度上堆叠具有相同维度的张量。

【讨论】:

【参考方案3】:

假设你有两个张量 a, b 尺寸相等,即 a ( A, B, C) 所以 b (A, B , C) 一个例子

a=torch.randn(2,3,4)
b=torch.randn(2,3,4)
print(a.size())  # 2, 3, 4
print(b.size()) # 2, 3, 4

f=torch.stack([a, b], dim=2)  # 2, 3, 2, 4
f

如果它们不是相同的暗淡,它就不会起作用。小心!!

【讨论】:

以上是关于如何使用 torch.stack 函数的主要内容,如果未能解决你的问题,请参考以下文章

如何将张量列表转换为 Torch::Tensor?

Pytorch 中 torch.cat() 函数解析

深度之眼PyTorch训练营第二期 ---2张量操作与线性回归

如何如何使用MOD函数?

c函数指针和指针函数如何使用何定义;如何调用使用

如何创建自定义 JQuery 函数以及如何使用它?