常用Tensor操作
Posted a-runner
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了常用Tensor操作相关的知识,希望对你有一定的参考价值。
最近在看开发文档们就顺便记录一下。
1: torch.index_select(input, dim, indices,out=None)
沿着指定维度对输入进行切片,取index
中指定的相应项(index
为一个LongTensor),然后返回到一个新的张量,
返回的张量与原始张量_Tensor_有相同的维度(在指定轴上)。
注意: 返回的张量不与原始张量共享内存空间。
参数:
- input (Tensor) – 输入张量
- dim (int) – 索引的轴
- index (LongTensor) – 包含索引下标的一维张量
- out (Tensor, optional) – 目标张量
>>> x = torch.randn(3, 4) >>> x 1.2045 2.4084 0.4001 1.1372 0.5596 1.5677 0.6219 -0.7954 1.3635 -1.2313 -0.5414 -1.8478 [torch.FloatTensor of size 3x4] >>> indices = torch.LongTensor([0, 2]) >>> torch.index_select(x, 0, indices) 1.2045 2.4084 0.4001 1.1372 1.3635 -1.2313 -0.5414 -1.8478 [torch.FloatTensor of size 2x4] >>> torch.index_select(x, 1, indices) 1.2045 0.4001 0.5596 0.6219 1.3635 -0.5414 [torch.FloatTensor of size 3x2]
2 torch.masked_select(input, mask,out=None)
根据掩码张量mask
中的二元值,取输入张量中的指定项( mask
为一个 ByteTensor),将取值返回到一个新的1D张量,
张量 mask
须跟input
张量有相同数量的元素数目,但形状或维度不需要相同。
注意: 返回的张量不与原始张量共享内存空间。
参数:
- input (Tensor) – 输入张量
- mask (ByteTensor) – 掩码张量,包含了二元索引值
- out (Tensor, optional) – 目标张量
例子:
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297, 0.3477],
[-1.2035, 1.2252, 0.5002, 0.6248],
[ 0.1307, -2.0608, 0.1244, 2.0139]])
>>> mask = x.ge(0.5) # 构建大于0.5的Bool tensor
>>> mask
tensor([[False, False, False, False],
[False, True, True, True],
[False, False, False, True]])
>>> torch.masked_select(x, mask)
tensor([ 1.2252, 0.5002, 0.6248, 2.0139])
3 torch.nonzero (input, out) 返回非0索引
返回一个包含输入input
中非零元素索引的张量。输出张量中的每行包含输入中非零元素的索引。
如果输入input
有n
维,则输出的索引张量output
的形状为 z x n, 这里 z 是输入张量input
中所有非零元素的个数。
参数:
- input (Tensor) – 源张量
- out (LongTensor, optional) – 包含索引值的结果张量
例子:
>>> torch.nonzero(torch.Tensor([1, 1, 1, 0, 1])) 0 1 2 4 [torch.LongTensor of size 4x1] >>> torch.nonzero(torch.Tensor([[0.6, 0.0, 0.0, 0.0], ... [0.0, 0.4, 0.0, 0.0], ... [0.0, 0.0, 1.2, 0.0], ... [0.0, 0.0, 0.0,-0.4]])) 0 0 1 1 2 2 3 3 [torch.LongTensor of size 4x2]
4 torch.clamp 设置上下阈值
torch.clamp(input, min, max, out=None) → Tensor
将输入input
张量每个元素的夹紧到区间 ([min, max] ),并返回结果到一个新张量。
| min, if x_i < min
y_i = | x_i, if min <= x_i <= max
| max, if x_i > max
>>> a = torch.randn(4) >>> a 1.3869 0.3912 -0.8634 -0.5468 [torch.FloatTensor of size 4] >>> torch.clamp(a, min=-0.5, max=0.5) 0.5000 0.3912 -0.5000 -0.5000 [torch.FloatTensor of size 4]
5 torch.frac 返回分数部分
torch.frac(tensor, out=None) → Tensor
>>> torch.frac(torch.Tensor([1, 2.5, -3.2])
torch.FloatTensor([0, 0.5, -0.2])
6 torch.lerp 线性插值(一次函数)
torch.lerp(start, end, weight, out=None)
对两个张量以start
,end
做线性插值, 将结果返回到输出张量。
即,( out_i=start_i+weight∗(end_i−start_i) )
参数:
- start (Tensor) – 起始点张量
- end (Tensor) – 终止点张量
- weight (float) – 插值公式的weight
- out (Tensor, optional) – 结果张量
>>> start = torch.arange(1, 5) >>> end = torch.Tensor(4).fill_(10) >>> start 1 2 3 4 [torch.FloatTensor of size 4] >>> end 10 10 10 10 [torch.FloatTensor of size 4] >>> torch.lerp(start, end, 0.5) 5.5000 6.0000 6.5000 7.0000 [torch.FloatTensor of size 4]
以上是关于常用Tensor操作的主要内容,如果未能解决你的问题,请参考以下文章