pytorch 笔记:gather 函数

Posted UQI-LIUWJ

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch 笔记:gather 函数相关的知识,希望对你有一定的参考价值。

torch.gather(input, dim, index, out=None) → Tensor

我们直接用例子来说明好了

以三位张量为例

out[i] [j] [k] = tensor[index[i][j][k]]  [j]  [k] 
# dim=0

out[i] [j] [k] = tensor[i]  [index[i][j][k]]  [k] 
# dim=1

out[i] [j] [k] = tensor[i]  [j]  [index[i][j][k]] 
# dim=3
import torch
a = torch.Tensor([[1,2],[3,4]])
torch.gather(a,
            0,
            index=torch.LongTensor([[0,0],[1,0]]))
'''
tensor([[1., 2.],
        [3., 2.]])
'''

这个怎么看呢

out[0][0]a[index[0][0]]  [0]]a[0][0]=1
out[1][0]a[index[1][0]]  [0]]a[1][0] =3
out[0][1]a[index[0][1]]  [1]]   a[0][1]=2
out[1][1]a[index[1][1]]  [1]]a[0][1]=2

还有一种用法在每一行选择一个的时候比较常用

import torch
a = torch.Tensor([[1,2,5,6],[3,4,7,8]])
torch.gather(a,1,torch.LongTensor([[0],[3]]))
'''
tensor([[1.],
        [8.]])
'''

以上是关于pytorch 笔记:gather 函数的主要内容,如果未能解决你的问题,请参考以下文章

pytorch中gather函数的理解。

pytorch - torch.gather 的倒数

torch.gather()之通俗易懂讲解

标签平滑Label Smoothing Demo(附pytorch的NLLLoss(),gather())

PyTorch中scatter和gather的用法

Pytorch的gather用法理解