PyTorch利用torch.cat()实现Tensor的拼接

Posted 算法与编程之美

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch利用torch.cat()实现Tensor的拼接相关的知识,希望对你有一定的参考价值。

问题

方法


import torch
from torch import nn

conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=32,
    kernel_size=3,
    stride=1,
    padding=1
)

conv2 = nn.Conv2d(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    stride=1,
    padding=1
)

x = torch.rand(128, 3, 224, 224)

x1 = conv1(x) # [128, 32, 224, 224]
x2 = conv2(x) # [128, 16, 224, 224]

# 表示对dim=1维进行cat操作,其他维度均不变
out = torch.cat([x1, x2], dim=1) 
print(out.shape) #[128, 48, 224, 224]


结语

以上是关于PyTorch利用torch.cat()实现Tensor的拼接的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch 中 torch.cat() 函数解析

Pytorch中的torch.cat()函数

pytorch 常用函数参数详解

pytorch中的torch.cat()矩阵拼接的用法及理解

pytorch中torch.cat() 和paddle中的paddle.concat()函数用法

pytorch-torch2:张量计算和连接