Pytorch-数据类型

Posted vshen999

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch-数据类型相关的知识,希望对你有一定的参考价值。

1.张量数据类型

Pytorch常用的数据类型,其中FloatTensor、DoubleTensor、ByteTensor、IntTensor最常用。GPU和CPU的Tensor不相同。技术图片

 

 

  •  数据类型检查使用isinstance()
import torch

a = torch.randn(2,3)
#torch.FloatTensor
a.type()
#true
isinstance(a,torch.FloatTensor)
  • 标量,torch.tensor(),t是小写的
import torch

a = torch.tensor(2)
#0
print(len(a.shape))
#torch.Size([])
print(a.size())
  • 1维张量
import torch
import numpy as np

#torch.tensor里边放的是数
a = torch.tensor([2.1])
print(a.shape)

#torch.FloatTensor里边放的是维度
a = torch.FloatTensor(2)
#tensor([5.6052e-45, 0.0000e+00])
print(a)

a = np.ones(2)
#tensor([1., 1.], dtype=torch.float64)
b = torch.from_numpy(a)
print(isinstance(b,torch.DoubleTensor))
  • 多维张量
import torch
import numpy as np

# tensor([[[0.9539, 0.4338, 0.9842],
#          [0.2288, 0.0569, 0.9997]]])
a = torch.rand(1,2,3)
#3,返回维度
a.dim()
#6,返回元素数,1*2*3
a.numel()
  • 初始化张量
import torch

#(0,1)之间均值分布初始化
a = torch.rand(3,3)
# tensor([[0.7140, 0.3779, 0.7530],
#         [0.1225, 0.2168, 0.9868],
#         [0.6421, 0.0806, 0.1370]])
print(a)

#接收一个tensor,把a的shape读出来,生成一个a的shape的均值分布
b = torch.rand_like(a)
# tensor([[0.6055, 0.3282, 0.4211],
#         [0.9757, 0.3171, 0.5054],
#         [0.3429, 0.1091, 0.9734]])
print(b)

#生成1-10之间的整数,生成形状是[3,3]
c = torch.randint(1,10,[3,3])
# tensor([[6, 6, 9],
#         [8, 1, 1],
#         [9, 6, 8]])
print(c)
  • torch.full()
#生成一个2*3,元素都是7的tnensor
#生成标量使用[]
a = torch.full([2,3],7)
print(a)
  • torch.arange()
#生成[0,10)之间的等差数列
#tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
a = torch.arange(0,10)
#tensor([0, 2, 4, 6, 8])
b = torch.arange(0,10,2)
  • torch.linspace()
#[0,10]之间均匀取11个点
#tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
a = torch.linspace(0,10,steps=11)
print(a)
  • torch.ones()、torch.zeros()、torch.eye()
#单位阵
a = torch.ones(3,3)#0矩阵
b = torch.zeros(3,4)#对角阵
c = torch.eye(3)

2.索引

索引是从第0维开始的

import torch

a = torch.rand(4,3,28,28)

#torch.Size([3, 28, 28])
print(a[0])
#torch.Size([28, 28])
print(a[0,1])
  • :选取
import torch

a = torch.rand(4,3,28,28)
#选择0维度0,1两个元素,不包括2
print(a[:2])
#选择0维度最后一个元素
#索引正编号[0,1,2,3],反编号[-4,-3,-2,-1]
#等价于a[3]
print(a[-1:])

总的形式可以表达为:start:end:step,::2表示元素隔1个进行采样

  • 选择特定的行,index_select()
import torch

a = torch.rand(4,3,28,28)

#第0个维度,选择特定的第0个和第三个元素
b = a.index_select(0,torch.tensor([0,3]))
  • masked_select(),选择特定位置元素
import torch

# tensor([[0.1956, 0.1843, 0.2313],
#         [0.1363, 0.4729, 0.7214],
#         [0.5356, 0.4904, 0.5742]])
a = torch.rand(3,3)
#值大于0.5的元素位置设为1
# tensor([[0, 0, 0],
#         [0, 0, 1],
#         [1, 0, 1]], dtype=torch.uint8)
mask = a.ge(0.5)
#tensor([0.7214, 0.5356, 0.5742])
c = torch.masked_select(a,mask)

2.维度变换

  • view()、reshape()
import torch

a = torch.rand(4,1,28,28)
#torch.Size([4, 784])
b = a.view(4,28*28)
#view完以后会丢失a的shape信息
#b.view(4,28,28,1),这样变换语法上没有错误,但是由于shape和a不匹配,变换的数据并不是原数据,造成数据污染
  • 增加维度,unsqueeze()
import torch

#torch.Size([2])
a = torch.tensor([15,22])
#正数表示在哪个维度之前添加一个维度,值<a.dim()+1
#torch.Size([2, 1])
# tensor([[15],
#         [22]])
b = a.unsqueeze(1)
#torch.Size([1, 2])
#tensor([[15, 22]])
b = a.unsqueeze(0)

#负数表示在哪个维度之后插入一个维度,值≥-a.dim()-1
#torch.Size([2, 1])
b=a.unsqueeze(-1)
#torch.Size([1, 2])
b=a.unsqueeze(-2)
# 正, 0,1, 2, 3
#    [4,3,28,28]
# 负,-4,-3,-2,-1
x = torch.rand(4,3,28,28)
#等价x.unsqueeze(2)
q = x.unsqueeze(-3)
  • squeeze(),维度减少
import torch

x = torch.rand(1,1,28,1)
#不填参数会把所有为1的维度减少
a = x.squeeze()
#把第3维维度减少
#torch.Size([1, 1, 28])
b = x.squeeze(3)
  • expand(),维度扩展
import torch

#维度拓展可以把所有是1的维度扩展到需要维度
x = torch.tensor([[1,2]])
# tensor([[1, 2],
#         [1, 2]])
x = x.expand(2,2)
y = torch.tensor([[3,4],[5,6]])
# tensor([[4, 6],
#         [6, 8]])
a = x +y
  • transpose(),交换行列

技术图片

 

  •  permute(),把原来的行列交换
import torch

x = torch.rand(4,3,28,28)
#交换0维和1维
#torch.Size([3, 4, 28, 28])
b = x.permute(1,0,2,3)

permute和transpose会打乱数据在内存中的位置,如果数据在内存中不连续了,使用contiguous()把数据变成连续的

以上是关于Pytorch-数据类型的主要内容,如果未能解决你的问题,请参考以下文章

利用pytorch的载入训练npy类型数据代码

PyTorch学习数据格式转换

Pytorch-数据类型

PyTorch DataLoader 错误:“类型”类型的对象没有 len()

Pytorch 保存模型用户警告:无法检索网络类型容器的源代码

pytorch tensor/数据类型转化