Pytorch-数据类型
Posted vshen999
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch-数据类型相关的知识,希望对你有一定的参考价值。
1.张量数据类型
Pytorch常用的数据类型,其中FloatTensor、DoubleTensor、ByteTensor、IntTensor最常用。GPU和CPU的Tensor不相同。
- 数据类型检查使用isinstance()
import torch a = torch.randn(2,3) #torch.FloatTensor a.type() #true isinstance(a,torch.FloatTensor)
- 标量,torch.tensor(),t是小写的
import torch a = torch.tensor(2) #0 print(len(a.shape)) #torch.Size([]) print(a.size())
- 1维张量
import torch import numpy as np #torch.tensor里边放的是数 a = torch.tensor([2.1]) print(a.shape) #torch.FloatTensor里边放的是维度 a = torch.FloatTensor(2) #tensor([5.6052e-45, 0.0000e+00]) print(a) a = np.ones(2) #tensor([1., 1.], dtype=torch.float64) b = torch.from_numpy(a) print(isinstance(b,torch.DoubleTensor))
- 多维张量
import torch import numpy as np # tensor([[[0.9539, 0.4338, 0.9842], # [0.2288, 0.0569, 0.9997]]]) a = torch.rand(1,2,3) #3,返回维度 a.dim() #6,返回元素数,1*2*3 a.numel()
- 初始化张量
import torch #(0,1)之间均值分布初始化 a = torch.rand(3,3) # tensor([[0.7140, 0.3779, 0.7530], # [0.1225, 0.2168, 0.9868], # [0.6421, 0.0806, 0.1370]]) print(a) #接收一个tensor,把a的shape读出来,生成一个a的shape的均值分布 b = torch.rand_like(a) # tensor([[0.6055, 0.3282, 0.4211], # [0.9757, 0.3171, 0.5054], # [0.3429, 0.1091, 0.9734]]) print(b) #生成1-10之间的整数,生成形状是[3,3] c = torch.randint(1,10,[3,3]) # tensor([[6, 6, 9], # [8, 1, 1], # [9, 6, 8]]) print(c)
- torch.full()
#生成一个2*3,元素都是7的tnensor #生成标量使用[] a = torch.full([2,3],7) print(a)
- torch.arange()
#生成[0,10)之间的等差数列 #tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) a = torch.arange(0,10) #tensor([0, 2, 4, 6, 8]) b = torch.arange(0,10,2)
- torch.linspace()
#[0,10]之间均匀取11个点 #tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) a = torch.linspace(0,10,steps=11) print(a)
- torch.ones()、torch.zeros()、torch.eye()
#单位阵 a = torch.ones(3,3)#0矩阵 b = torch.zeros(3,4)#对角阵 c = torch.eye(3)
2.索引
索引是从第0维开始的
import torch a = torch.rand(4,3,28,28) #torch.Size([3, 28, 28]) print(a[0]) #torch.Size([28, 28]) print(a[0,1])
- :选取
import torch a = torch.rand(4,3,28,28) #选择0维度0,1两个元素,不包括2 print(a[:2]) #选择0维度最后一个元素 #索引正编号[0,1,2,3],反编号[-4,-3,-2,-1] #等价于a[3] print(a[-1:])
总的形式可以表达为:start:end:step,::2表示元素隔1个进行采样
- 选择特定的行,index_select()
import torch a = torch.rand(4,3,28,28) #第0个维度,选择特定的第0个和第三个元素 b = a.index_select(0,torch.tensor([0,3]))
- masked_select(),选择特定位置元素
import torch # tensor([[0.1956, 0.1843, 0.2313], # [0.1363, 0.4729, 0.7214], # [0.5356, 0.4904, 0.5742]]) a = torch.rand(3,3) #值大于0.5的元素位置设为1 # tensor([[0, 0, 0], # [0, 0, 1], # [1, 0, 1]], dtype=torch.uint8) mask = a.ge(0.5) #tensor([0.7214, 0.5356, 0.5742]) c = torch.masked_select(a,mask)
2.维度变换
- view()、reshape()
import torch a = torch.rand(4,1,28,28) #torch.Size([4, 784]) b = a.view(4,28*28) #view完以后会丢失a的shape信息 #b.view(4,28,28,1),这样变换语法上没有错误,但是由于shape和a不匹配,变换的数据并不是原数据,造成数据污染
- 增加维度,unsqueeze()
import torch #torch.Size([2]) a = torch.tensor([15,22]) #正数表示在哪个维度之前添加一个维度,值<a.dim()+1 #torch.Size([2, 1]) # tensor([[15], # [22]]) b = a.unsqueeze(1) #torch.Size([1, 2]) #tensor([[15, 22]]) b = a.unsqueeze(0) #负数表示在哪个维度之后插入一个维度,值≥-a.dim()-1 #torch.Size([2, 1]) b=a.unsqueeze(-1) #torch.Size([1, 2]) b=a.unsqueeze(-2) # 正, 0,1, 2, 3 # [4,3,28,28] # 负,-4,-3,-2,-1 x = torch.rand(4,3,28,28) #等价x.unsqueeze(2) q = x.unsqueeze(-3)
- squeeze(),维度减少
import torch x = torch.rand(1,1,28,1) #不填参数会把所有为1的维度减少 a = x.squeeze() #把第3维维度减少 #torch.Size([1, 1, 28]) b = x.squeeze(3)
- expand(),维度扩展
import torch #维度拓展可以把所有是1的维度扩展到需要维度 x = torch.tensor([[1,2]]) # tensor([[1, 2], # [1, 2]]) x = x.expand(2,2) y = torch.tensor([[3,4],[5,6]]) # tensor([[4, 6], # [6, 8]]) a = x +y
- transpose(),交换行列
- permute(),把原来的行列交换
import torch x = torch.rand(4,3,28,28) #交换0维和1维 #torch.Size([3, 4, 28, 28]) b = x.permute(1,0,2,3)
permute和transpose会打乱数据在内存中的位置,如果数据在内存中不连续了,使用contiguous()把数据变成连续的
以上是关于Pytorch-数据类型的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch DataLoader 错误:“类型”类型的对象没有 len()