torch.linspace,unsqueeze()以及squeeze()函数
Posted wmy-ncut
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch.linspace,unsqueeze()以及squeeze()函数相关的知识,希望对你有一定的参考价值。
1.torch.linspace(start,end,steps=100,dtype)
作用是返回一个一维的tensor(张量),其中dtype是返回的数据类型。
import torch print(torch.linspace(-1,1,5))
输出结果为:tensor([-1.0000, -0.5000, 0.0000, 0.5000, 1.0000])
2.unsqueeze()函数
在指定位置增加维度。
import torch a=torch.arange(0,6) #a是一维向量 b=a.reshape(2,3) #b是二维向量 c=b.unsqueeze(1) #c是三维向量,在b的第二维上增加一个维度 print(a) print(b) print(c) print(c.size())
a的维度为1x6
b的维度为2x3
b的维度为2x1x3
若想在倒数第二个维度增加一个维度,则c=b.unsqueeze(-1)
3.squeeze()函数
可去掉维度为1的维度。
import torch a=torch.arange(0,6) #a是一维向量 b=a.reshape(2,3) c=b.unsqueeze(1) print(c) print(c.size()) d=c.squeeze(1) print(d) print(d.size())
输出结果为:
以上是关于torch.linspace,unsqueeze()以及squeeze()函数的主要内容,如果未能解决你的问题,请参考以下文章