PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别

Posted love the future

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别相关的知识,希望对你有一定的参考价值。

文章目录

PyTorch中Stack和cat

两者之间主要的区别:
cat是在现有轴进行拼接,stack要求两个tensor的shape必须一样,在一个新的轴上堆叠。

首先看一下在PyTorch中如何增加tensor的维度,我们通常使用unsqueeze

import torch
t1 = torch.tensor([1,1,1])
print(t1.unsqueeze(dim=0))
print(t1.unsqueeze(dim=1))
print(t1.shape)
print(t1.unsqueeze(dim=0).shape)
print(t1.unsqueeze(dim=1).shape)
# tensor([[1, 1, 1]])
# tensor([[1],
#         [1],
#         [1]])
# torch.Size([3])  一维
# torch.Size([1, 3])二维
# torch.Size([3, 1])二维

在PyTorch中stack会增加维度,cat在现有的轴上增加轴的长度。
cat()不存在dim=1的连接,如果需要用cat进行dim=1的连接,需要先将其进行扩维

import torch

t1 = torch.tensor([1,1,1])
t2 = torch.tensor([2,2,2])
t3 = torch.tensor([3,3,3])
print(torch.cat(
    (t1,t2,t3)
    ,dim=0
))
print(torch.stack(
    (t1,t2,t3)
    ,dim=0
))
# tensor([1, 1, 1, 2, 2, 2, 3, 3, 3]) 一维i,并没有增加维度
# tensor([[1, 1, 1],  变成了二维,增加了一维
#         [2, 2, 2],
#         [3, 3, 3]])

print(torch.stack(
    (t1,t2,t3)
    ,dim=1
))
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])

print(torch.cat(
    (
         t1.unsqueeze(0)
        ,t2.unsqueeze(0)
        ,t3.unsqueeze(0)
    )
    ,dim=0
))

print(torch.cat(
    (
         t1.unsqueeze(1)
        ,t2.unsqueeze(1)
        ,t3.unsqueeze(1)
    )
    ,dim=1
))

# tensor([[1, 1, 1],
#         [2, 2, 2],
#         [3, 3, 3]])
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])

TensorFlow中的Stack Vs Concat

import tensorflow as tf

t1 = tf.constant([1,1,1])
t2 = tf.constant([2,2,2])
t3 = tf.constant([3,3,3])
tf.concat(
    (t1,t2,t3)
    ,axis=0
)
#<tf.Tensor: shape=(9,), dtype=int32, numpy=array([1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=int32)>
tf.stack(
    (t1,t2,t3)
    ,axis=0
)
tf.concat(
    (
         tf.expand_dims(t1, 0)
        ,tf.expand_dims(t2, 0)
        ,tf.expand_dims(t3, 0)
    )    
    ,axis=0
)
# <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
# array([[1, 1, 1],
#        [2, 2, 2],
#        [3, 3, 3]], dtype=int32)>
tf.stack(
    (t1,t2,t3)
    ,axis=1
)
tf.concat(
    (
         tf.expand_dims(t1, 1)
        ,tf.expand_dims(t2, 1)
        ,tf.expand_dims(t3, 1)
    )
    ,axis=1
)
# <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
# array([[1, 2, 3],
#        [1, 2, 3],
#        [1, 2, 3]], dtype=int32)>

numpy中的Stack Vs Concatenate

import numpy as np

t1 = np.array([1,1,1])
t2 = np.array([2,2,2])
t3 = np.array([3,3,3])
np.concatenate(
    (t1,t2,t3)
    ,axis=0
)
#array([1, 1, 1, 2, 2, 2, 3, 3, 3])
np.stack(
    (t1,t2,t3)
    ,axis=0
)
np.concatenate(
    (
         np.expand_dims(t1, 0)
        ,np.expand_dims(t2, 0)
        ,np.expand_dims(t3, 0)
    )
    ,axis=0
)
# array([[1, 1, 1],
#        [2, 2, 2],
#        [3, 3, 3]])
np.stack(
    (t1,t2,t3)
    ,axis=1
)
np.concatenate(
    (
         np.expand_dims(t1, 1)
        ,np.expand_dims(t2, 1)
        ,np.expand_dims(t3, 1)
    )
    ,axis=1
)
# array([[1, 2, 3],
#        [1, 2, 3],
#        [1, 2, 3]])

对比

关于cat和stack的实例

多张图片合并成单个批次

有3各张量分别代表一张图片,每张图片的维度是3,将3张图片合并成一个批次,即(batch_szie,c,w,h)的形式

import torch
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)

torch.stack(
    (t1,t2,t3)
    ,dim=0
).shape

## output ##
#torch.Size([3, 3, 28, 28])

将多个小批次合并成一个Batch

有三个small batch,其shape是(1,3,28,28),将其合并成([3, 3, 28, 28])

import torch
t1 = torch.zeros(1,3,28,28)
t2 = torch.zeros(1,3,28,28)
t3 = torch.zeros(1,3,28,28)
torch.cat(
    (t1,t2,t3)
    ,dim=0
).shape

## output ##
#torch.Size([3, 3, 28, 28])

将图像与现有的batch连接

import torch
batch = torch.zeros(3,3,28,28)
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)
torch.cat(
    (
        batch
        ,torch.stack(
            (t1,t2,t3)
            ,dim=0
        )
    )
    ,dim=0
).shape

## output ##
#torch.Size([6, 3, 28, 28])

另外一种方法,两种方式最后的实现结果一样

import torch
batch = torch.zeros(3,3,28,28)
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)
torch.cat(
    (
        batch
        ,t1.unsqueeze(0)
        ,t2.unsqueeze(0)
        ,t3.unsqueeze(0)
    )
    ,dim=0
).shape

## output ##
#torch.Size([6, 3, 28, 28])

以上是关于PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别

Pytorch 中 torch.cat() 函数解析

pytorch数据拼接与拆分

pytorch的Tensor的操作

PyTorch 中的 index_select 和 tensor[sequence] 之间有啥区别吗?

PyTorch 1.8 和 Tensorflow 2.5,我该用哪个?