torch.stack() 和 torch.cat() 函数有啥区别?
Posted
技术标签:
【中文标题】torch.stack() 和 torch.cat() 函数有啥区别?【英文标题】:What's the difference between torch.stack() and torch.cat()?torch.stack() 和 torch.cat() 函数有什么区别? 【发布时间】:2019-06-15 20:40:55 【问题描述】:OpenAI 用于强化学习的 REINFORCE 和 actor-critic 示例具有以下代码:
REINFORCE:
policy_loss = torch.cat(policy_loss).sum()
actor-critic:
loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
一个使用torch.cat
,另一个使用torch.stack
。
据我了解,the doc 并没有明确区分它们。
我很高兴知道这些功能之间的区别。
【问题讨论】:
如果您对张量的可变长度嵌套列表感兴趣,这里的链接似乎有解决方案:***.com/questions/55050717/… 和 discuss.pytorch.org/t/… 【参考方案1】:stack
沿新维度连接张量序列。
cat
在给定维度中连接给定的 seq 张量序列。
因此,如果 A
和 B
的形状为 (3, 4),torch.cat([A, B], dim=0)
的形状为 (6, 4),torch.stack([A, B], dim=0)
的形状为 (2, 3, 4)。
【讨论】:
因此,torch.stack([A,B],dim = 0) 等价于 torch.cat([A.unsqueeze(0),b.unsqueeze(0)],dim = 0 ) .因此,如果您发现自己在组合张量之前执行了许多 unsqueeze() 操作,您可能会使用 stack() 来简化代码。 只是为了补充,在问题的 OpenAI 示例中,torch.stack
和 torch.cat
可以在任一代码行中互换使用,因为 torch.stack(tensors).sum() == torch.cat(tensors).sum()
。【参考方案2】:
原始答案缺少一个独立的好例子,所以这里是:
import torch
# stack vs cat
# cat "extends" a list in the given dimension e.g. adds more rows or columns
x = torch.randn(2, 3)
print(f'x.size()')
# add more rows (thus increasing the dimensionality of the column space to 2 -> 6)
xnew_from_cat = torch.cat((x, x, x), 0)
print(f'xnew_from_cat.size()')
# add more columns (thus increasing the dimensionality of the row space to 3 -> 9)
xnew_from_cat = torch.cat((x, x, x), 1)
print(f'xnew_from_cat.size()')
print()
# stack serves the same role as append in lists. i.e. it doesn't change the original
# vector space but instead adds a new index to the new tensor, so you retain the ability
# get the original tensor you added to the list by indexing in the new dimension
xnew_from_stack = torch.stack((x, x, x, x), 0)
print(f'xnew_from_stack.size()')
xnew_from_stack = torch.stack((x, x, x, x), 1)
print(f'xnew_from_stack.size()')
xnew_from_stack = torch.stack((x, x, x, x), 2)
print(f'xnew_from_stack.size()')
# default appends at the from
xnew_from_stack = torch.stack((x, x, x, x))
print(f'xnew_from_stack.size()')
print('I like to think of xnew_from_stack as a \"tensor list\" that you can pop from the front')
输出:
torch.Size([2, 3])
torch.Size([6, 3])
torch.Size([2, 9])
torch.Size([4, 2, 3])
torch.Size([2, 4, 3])
torch.Size([2, 3, 4])
torch.Size([4, 2, 3])
I like to think of xnew_from_stack as a "tensor list"
以下是定义供参考:
cat:连接给定维度中给定的 seq 张量序列。结果是特定尺寸会改变尺寸,例如dim=0 那么您正在向行中添加元素,这会增加列空间的维度。
stack:沿新维度连接张量序列。我喜欢将此视为火炬“追加”操作,因为您可以通过从前面“弹出”来索引/获取原始张量。没有参数,它将张量附加到张量的前面。
相关:
这里是来自 pytorch 论坛的链接,其中讨论了这个问题:https://discuss.pytorch.org/t/best-way-to-convert-a-list-to-a-tensor/59949/8 虽然我希望tensor.torch
将张量的嵌套列表转换为具有许多维度的大张量,以尊重嵌套列表的深度。李>
更新:使用相同大小的嵌套列表
def tensorify(lst):
"""
List must be nested list of tensors (with no varying lengths within a dimension).
Nested list of nested lengths [D1, D2, ... DN] -> tensor([D1, D2, ..., DN)
:return: nested list D
"""
# base case, if the current list is not nested anymore, make it into tensor
if type(lst[0]) != list:
if type(lst) == torch.Tensor:
return lst
elif type(lst[0]) == torch.Tensor:
return torch.stack(lst, dim=0)
else: # if the elements of lst are floats or something like that
return torch.tensor(lst)
current_dimension_i = len(lst)
for d_i in range(current_dimension_i):
tensor = tensorify(lst[d_i])
lst[d_i] = tensor
# end of loop lst[d_i] = tensor([D_i, ... D_0])
tensor_lst = torch.stack(lst, dim=0)
return tensor_lst
这里有一些单元测试(我没有编写更多测试,但它适用于我的真实代码,所以我相信它没问题。如果需要,请随时通过添加更多测试来帮助我):
def test_tensorify():
t = [1, 2, 3]
print(tensorify(t).size())
tt = [t, t, t]
print(tensorify(tt))
ttt = [tt, tt, tt]
print(tensorify(ttt))
if __name__ == '__main__':
test_tensorify()
print('Done\a')
【讨论】:
【参考方案3】:t1 = torch.tensor([[1, 2],
[3, 4]])
t2 = torch.tensor([[5, 6],
[7, 8]])
torch.stack |
torch.cat |
---|---|
'Stacks' a sequence of tensors along a new dimension: | 'Concatenates' a sequence of tensors along an existing dimension: |
这些函数类似于numpy.stack
和numpy.concatenate
。
【讨论】:
以上是关于torch.stack() 和 torch.cat() 函数有啥区别?的主要内容,如果未能解决你的问题,请参考以下文章
深度之眼PyTorch训练营第二期 ---2张量操作与线性回归