Pytorch 中 torch.cat() 函数解析

Posted 怎样才能回到过去

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch 中 torch.cat() 函数解析相关的知识,希望对你有一定的参考价值。

Pytorch 中主要有两个拼接函数:

  1. torch.stack()

  1. torch.cat()

这里主要介绍 torch.cat()

1 函数作用

  • 对给定维度上的输入的 Tensor 序列进行连接

  • torch.cat() 和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor

2 参数解析

import torch

outputs = torch.cat(inputs, dim) -> Tensor
  • inputs : 待连接的张量, 必须是 Tensor, 注意连接多个 Tensor 时, 需要将多个 Tensor 放入一个 list[] 中

  • dim : 从那个维度进行连接, 必须小于维度的个数

3 示例

import torch

a = torch.tensor([[1, 1, 1], [2, 2, 2]])
b = torch.tensor([[3, 3, 3], [4, 4, 4]])
print("dim = 0 :", torch.cat([a, b], dim = 0))
print("dim = 1 :", torch.cat([a, b], dim = 1))

>>> dim = 0 : tensor([[1, 1, 1],
                    [2, 2, 2],
                    [3, 3, 3],
                    [4, 4, 4]])
>>> dim = 1 : tensor([[1, 1, 1, 3, 3, 3],
                    [2, 2, 2, 4, 4, 4]])

以上是关于Pytorch 中 torch.cat() 函数解析的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch中的torch.cat()函数

pytorch中torch.cat() 和paddle中的paddle.concat()函数用法

pytorch中的torch.cat()矩阵拼接的用法及理解

PyTorch利用torch.cat()实现Tensor的拼接

PyTorch利用torch.cat()实现Tensor的拼接

pytorch-torch2:张量计算和连接