Pytorch 中 torch.cat() 函数解析
Posted 怎样才能回到过去
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch 中 torch.cat() 函数解析相关的知识,希望对你有一定的参考价值。
Pytorch 中主要有两个拼接函数:
torch.stack()
torch.cat()
这里主要介绍 torch.cat()
1 函数作用
对给定维度上的输入的 Tensor 序列进行连接
torch.cat() 和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor
2 参数解析
import torch
outputs = torch.cat(inputs, dim) -> Tensor
inputs : 待连接的张量, 必须是 Tensor, 注意连接多个 Tensor 时, 需要将多个 Tensor 放入一个 list[] 中
dim : 从那个维度进行连接, 必须小于维度的个数
3 示例
import torch
a = torch.tensor([[1, 1, 1], [2, 2, 2]])
b = torch.tensor([[3, 3, 3], [4, 4, 4]])
print("dim = 0 :", torch.cat([a, b], dim = 0))
print("dim = 1 :", torch.cat([a, b], dim = 1))
>>> dim = 0 : tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]])
>>> dim = 1 : tensor([[1, 1, 1, 3, 3, 3],
[2, 2, 2, 4, 4, 4]])
以上是关于Pytorch 中 torch.cat() 函数解析的主要内容,如果未能解决你的问题,请参考以下文章
pytorch中torch.cat() 和paddle中的paddle.concat()函数用法
pytorch中的torch.cat()矩阵拼接的用法及理解
PyTorch利用torch.cat()实现Tensor的拼接