PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别
Posted love the future
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别相关的知识,希望对你有一定的参考价值。
文章目录
PyTorch中Stack和cat
两者之间主要的区别:
cat是在现有轴进行拼接,stack要求两个tensor的shape必须一样,在一个新的轴上堆叠。
首先看一下在PyTorch中如何增加tensor的维度,我们通常使用unsqueeze
import torch
t1 = torch.tensor([1,1,1])
print(t1.unsqueeze(dim=0))
print(t1.unsqueeze(dim=1))
print(t1.shape)
print(t1.unsqueeze(dim=0).shape)
print(t1.unsqueeze(dim=1).shape)
# tensor([[1, 1, 1]])
# tensor([[1],
# [1],
# [1]])
# torch.Size([3]) 一维
# torch.Size([1, 3])二维
# torch.Size([3, 1])二维
在PyTorch中stack会增加维度,cat在现有的轴上增加轴的长度。
cat()不存在dim=1的连接,如果需要用cat进行dim=1的连接,需要先将其进行扩维
import torch
t1 = torch.tensor([1,1,1])
t2 = torch.tensor([2,2,2])
t3 = torch.tensor([3,3,3])
print(torch.cat(
(t1,t2,t3)
,dim=0
))
print(torch.stack(
(t1,t2,t3)
,dim=0
))
# tensor([1, 1, 1, 2, 2, 2, 3, 3, 3]) 一维i,并没有增加维度
# tensor([[1, 1, 1], 变成了二维,增加了一维
# [2, 2, 2],
# [3, 3, 3]])
print(torch.stack(
(t1,t2,t3)
,dim=1
))
# tensor([[1, 2, 3],
# [1, 2, 3],
# [1, 2, 3]])
print(torch.cat(
(
t1.unsqueeze(0)
,t2.unsqueeze(0)
,t3.unsqueeze(0)
)
,dim=0
))
print(torch.cat(
(
t1.unsqueeze(1)
,t2.unsqueeze(1)
,t3.unsqueeze(1)
)
,dim=1
))
# tensor([[1, 1, 1],
# [2, 2, 2],
# [3, 3, 3]])
# tensor([[1, 2, 3],
# [1, 2, 3],
# [1, 2, 3]])
TensorFlow中的Stack Vs Concat
import tensorflow as tf
t1 = tf.constant([1,1,1])
t2 = tf.constant([2,2,2])
t3 = tf.constant([3,3,3])
tf.concat(
(t1,t2,t3)
,axis=0
)
#<tf.Tensor: shape=(9,), dtype=int32, numpy=array([1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=int32)>
tf.stack(
(t1,t2,t3)
,axis=0
)
tf.concat(
(
tf.expand_dims(t1, 0)
,tf.expand_dims(t2, 0)
,tf.expand_dims(t3, 0)
)
,axis=0
)
# <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
# array([[1, 1, 1],
# [2, 2, 2],
# [3, 3, 3]], dtype=int32)>
tf.stack(
(t1,t2,t3)
,axis=1
)
tf.concat(
(
tf.expand_dims(t1, 1)
,tf.expand_dims(t2, 1)
,tf.expand_dims(t3, 1)
)
,axis=1
)
# <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
# array([[1, 2, 3],
# [1, 2, 3],
# [1, 2, 3]], dtype=int32)>
numpy中的Stack Vs Concatenate
import numpy as np
t1 = np.array([1,1,1])
t2 = np.array([2,2,2])
t3 = np.array([3,3,3])
np.concatenate(
(t1,t2,t3)
,axis=0
)
#array([1, 1, 1, 2, 2, 2, 3, 3, 3])
np.stack(
(t1,t2,t3)
,axis=0
)
np.concatenate(
(
np.expand_dims(t1, 0)
,np.expand_dims(t2, 0)
,np.expand_dims(t3, 0)
)
,axis=0
)
# array([[1, 1, 1],
# [2, 2, 2],
# [3, 3, 3]])
np.stack(
(t1,t2,t3)
,axis=1
)
np.concatenate(
(
np.expand_dims(t1, 1)
,np.expand_dims(t2, 1)
,np.expand_dims(t3, 1)
)
,axis=1
)
# array([[1, 2, 3],
# [1, 2, 3],
# [1, 2, 3]])
对比
关于cat和stack的实例
多张图片合并成单个批次
有3各张量分别代表一张图片,每张图片的维度是3,将3张图片合并成一个批次,即(batch_szie,c,w,h)的形式
import torch
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)
torch.stack(
(t1,t2,t3)
,dim=0
).shape
## output ##
#torch.Size([3, 3, 28, 28])
将多个小批次合并成一个Batch
有三个small batch,其shape是(1,3,28,28),将其合并成([3, 3, 28, 28])
import torch
t1 = torch.zeros(1,3,28,28)
t2 = torch.zeros(1,3,28,28)
t3 = torch.zeros(1,3,28,28)
torch.cat(
(t1,t2,t3)
,dim=0
).shape
## output ##
#torch.Size([3, 3, 28, 28])
将图像与现有的batch连接
import torch
batch = torch.zeros(3,3,28,28)
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)
torch.cat(
(
batch
,torch.stack(
(t1,t2,t3)
,dim=0
)
)
,dim=0
).shape
## output ##
#torch.Size([6, 3, 28, 28])
另外一种方法,两种方式最后的实现结果一样
import torch
batch = torch.zeros(3,3,28,28)
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)
torch.cat(
(
batch
,t1.unsqueeze(0)
,t2.unsqueeze(0)
,t3.unsqueeze(0)
)
,dim=0
).shape
## output ##
#torch.Size([6, 3, 28, 28])
以上是关于PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别