torch.stack() 和 torch.cat() 函数有啥区别?

Posted

技术标签:

【中文标题】torch.stack() 和 torch.cat() 函数有啥区别?【英文标题】:What's the difference between torch.stack() and torch.cat()?torch.stack() 和 torch.cat() 函数有什么区别? 【发布时间】:2019-06-15 20:40:55 【问题描述】:

OpenAI 用于强化学习的 REINFORCE 和 actor-critic 示例具有以下代码:

REINFORCE:

policy_loss = torch.cat(policy_loss).sum()

actor-critic:

loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

一个使用torch.cat,另一个使用torch.stack

据我了解,the doc 并没有明确区分它们。

我很高兴知道这些功能之间的区别。

【问题讨论】:

如果您对张量的可变长度嵌套列表感兴趣,这里的链接似乎有解决方案:***.com/questions/55050717/… 和 discuss.pytorch.org/t/… 【参考方案1】:

stack

沿新维度连接张量序列。

cat

在给定维度中连接给定的 seq 张量序列。

因此,如果 AB 的形状为 (3, 4),torch.cat([A, B], dim=0) 的形状为 (6, 4),torch.stack([A, B], dim=0) 的形状为 (2, 3, 4)。

【讨论】:

因此,torch.stack([A,B],dim = 0) 等价于 torch.cat([A.unsqueeze(0),b.unsqueeze(0)],dim = 0 ) .因此,如果您发现自己在组合张量之前执行了许多 unsqueeze() 操作,您可能会使用 stack() 来简化代码。 只是为了补充,在问题的 OpenAI 示例中,torch.stacktorch.cat 可以在任一代码行中互换使用,因为 torch.stack(tensors).sum() == torch.cat(tensors).sum()【参考方案2】:

原始答案缺少一个独立的好例子,所以这里是:

import torch

# stack vs cat

# cat "extends" a list in the given dimension e.g. adds more rows or columns

x = torch.randn(2, 3)
print(f'x.size()')

# add more rows (thus increasing the dimensionality of the column space to 2 -> 6)
xnew_from_cat = torch.cat((x, x, x), 0)
print(f'xnew_from_cat.size()')

# add more columns (thus increasing the dimensionality of the row space to 3 -> 9)
xnew_from_cat = torch.cat((x, x, x), 1)
print(f'xnew_from_cat.size()')

print()

# stack serves the same role as append in lists. i.e. it doesn't change the original
# vector space but instead adds a new index to the new tensor, so you retain the ability
# get the original tensor you added to the list by indexing in the new dimension
xnew_from_stack = torch.stack((x, x, x, x), 0)
print(f'xnew_from_stack.size()')

xnew_from_stack = torch.stack((x, x, x, x), 1)
print(f'xnew_from_stack.size()')

xnew_from_stack = torch.stack((x, x, x, x), 2)
print(f'xnew_from_stack.size()')

# default appends at the from
xnew_from_stack = torch.stack((x, x, x, x))
print(f'xnew_from_stack.size()')

print('I like to think of xnew_from_stack as a \"tensor list\" that you can pop from the front')

输出:

torch.Size([2, 3])
torch.Size([6, 3])
torch.Size([2, 9])
torch.Size([4, 2, 3])
torch.Size([2, 4, 3])
torch.Size([2, 3, 4])
torch.Size([4, 2, 3])
I like to think of xnew_from_stack as a "tensor list"

以下是定义供参考:

cat:连接给定维度中给定的 seq 张量序列。结果是特定尺寸会改变尺寸,例如dim=0 那么您正在向行中添加元素,这会增加列空间的维度。

stack:沿新维度连接张量序列。我喜欢将此视为火炬“追加”操作,因为您可以通过从前面“弹出”来索引/获取原始张量。没有参数,它将张量附加到张量的前面。


相关:

这里是来自 pytorch 论坛的链接,其中讨论了这个问题:https://discuss.pytorch.org/t/best-way-to-convert-a-list-to-a-tensor/59949/8 虽然我希望 tensor.torch 将张量的嵌套列表转换为具有许多维度的大张量,以尊重嵌套列表的深度。李>

更新:使用相同大小的嵌套列表

def tensorify(lst):
    """
    List must be nested list of tensors (with no varying lengths within a dimension).
    Nested list of nested lengths [D1, D2, ... DN] -> tensor([D1, D2, ..., DN)

    :return: nested list D
    """
    # base case, if the current list is not nested anymore, make it into tensor
    if type(lst[0]) != list:
        if type(lst) == torch.Tensor:
            return lst
        elif type(lst[0]) == torch.Tensor:
            return torch.stack(lst, dim=0)
        else:  # if the elements of lst are floats or something like that
            return torch.tensor(lst)
    current_dimension_i = len(lst)
    for d_i in range(current_dimension_i):
        tensor = tensorify(lst[d_i])
        lst[d_i] = tensor
    # end of loop lst[d_i] = tensor([D_i, ... D_0])
    tensor_lst = torch.stack(lst, dim=0)
    return tensor_lst

这里有一些单元测试(我没有编写更多测试,但它适用于我的真实代码,所以我相信它没问题。如果需要,请随时通过添加更多测试来帮助我):


def test_tensorify():
    t = [1, 2, 3]
    print(tensorify(t).size())
    tt = [t, t, t]
    print(tensorify(tt))
    ttt = [tt, tt, tt]
    print(tensorify(ttt))

if __name__ == '__main__':
    test_tensorify()
    print('Done\a')

【讨论】:

【参考方案3】:
t1 = torch.tensor([[1, 2],
                   [3, 4]])

t2 = torch.tensor([[5, 6],
                   [7, 8]])
torch.stack torch.cat
'Stacks' a sequence of tensors along a new dimension: 'Concatenates' a sequence of tensors along an existing dimension:

这些函数类似于numpy.stacknumpy.concatenate

【讨论】:

以上是关于torch.stack() 和 torch.cat() 函数有啥区别?的主要内容,如果未能解决你的问题,请参考以下文章

深度之眼PyTorch训练营第二期 ---2张量操作与线性回归

如何使用 torch.stack 函数

[python]torch.cat和numpy.concatenate对应拼接

Pytorch中的torch.cat()函数

torch.cat

pytorch中torch.cat() 和paddle中的paddle.concat()函数用法