pytorch 笔记:gather 函数
Posted UQI-LIUWJ
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch 笔记:gather 函数相关的知识,希望对你有一定的参考价值。
torch.gather(input, dim, index, out=None) → Tensor
我们直接用例子来说明好了
以三位张量为例
out[i] [j] [k] = tensor[index[i][j][k]] [j] [k] # dim=0 out[i] [j] [k] = tensor[i] [index[i][j][k]] [k] # dim=1 out[i] [j] [k] = tensor[i] [j] [index[i][j][k]] # dim=3
import torch
a = torch.Tensor([[1,2],[3,4]])
torch.gather(a,
0,
index=torch.LongTensor([[0,0],[1,0]]))
'''
tensor([[1., 2.],
[3., 2.]])
'''
这个怎么看呢
out[0][0] | a[index[0][0]] [0]] | a[0][0]=1 |
out[1][0] | a[index[1][0]] [0]] | a[1][0] =3 |
out[0][1] | a[index[0][1]] [1]] | a[0][1]=2 |
out[1][1] | a[index[1][1]] [1]] | a[0][1]=2 |
还有一种用法在每一行选择一个的时候比较常用
import torch
a = torch.Tensor([[1,2,5,6],[3,4,7,8]])
torch.gather(a,1,torch.LongTensor([[0],[3]]))
'''
tensor([[1.],
[8.]])
'''
以上是关于pytorch 笔记:gather 函数的主要内容,如果未能解决你的问题,请参考以下文章