torch.gather函数的理解

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch.gather函数的理解相关的知识,希望对你有一定的参考价值。

参考技术A 建议先阅读官方文档,拿笔跟着给出的公式推导一次。

torch.gather官方文档

gather函数的定义为:
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
不常用的暂时不关注,于是函数常使用的样子如下:
torch.gather(input, dim, index)
函数的大致功能,给出input,根据dim和index确认从input中取出的数据内容,和最终输出的shape
1.确定输出的shape,输出的shape跟index的shape一致
例如index是一个2x3的tensor,那么输出就是一个2x3的tensor
2.确定输出的tensor的内容:根据input,dim,index三者确定
dim确定的是在input中取数据的时候,那个维度是从index中取数据来确定索引,而不是直接顺序索引:
举例
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

out[i][j] = input[index[i][j][k]][j] # if dim == 0
out[i][j] = input[i][index[i][j][k]] # if dim == 1

其实这个函数一般是结合其他的函数使用的,比如torch.max,torch.min。比如调用max函数如下:
torch.max(a,dim=dim_num)
那么此时返回的是一个tuple,包含了value和index_list
那么有value = torch.gather(a, dim=dim_num, index=index_list)

回调函数的理解

回调函数的理解

前言:

         刚开始用C语言听说过回调函数,但没有仔细去理解,随着工作的慢慢积累,逐步的用到了回调函数,本人认为,“回调函数”的理解对于很多人是一个槛,要想跨过,就得理解清楚,还得会用。这里就用本人的方式讲解一下回调函数如何理解。如有问题,欢迎指正[email protected]

 

第一步:通俗的解释“回调函数”

一、回调就是一种利用函数指针进行函数调用的过程。

二、你到一个商店买东西,刚好你要的东西没有货,于是你在店员那里留下了你的电话,过了几天店里有货了,店员就打了你的电话,然后你接到电话后就到店里去取了货。

三、回调函数是指 使用者自己定义一个函数,实现这个函数的程序内容,然后把这个函数(入口地址)作为参数传入别人(或系统)的函数中,由别人(或系统)的函数在运行时来调用的函数。函数是你实现的,但由别人(或系统)的函数在运行时通过参数传递的方式调用,这就是所谓的回调函数。简单来说,就是由别人的函数运行期间来回调你实现的函数。

注:以上解释均摘自网络,感谢这些大佬的解释。

 

或许到这里你已经明白了那么个意思,但是问题来了:

  1. 怎么看出来就是一个回调函数呢?
  2. 怎么定义一个回调函数呢?
  3. 怎么使用回调函数呢?

带着问题我们看实例,然后再解释。

 

第二步:实例解释 

#include "stdio.h"

int sub(int a, int b)//求和函数
{
    return a+b;
}

int mul(int a, int b)//求积函数
{
    return a*b;
}

int test(int(*p)(int,int), int a, int b)//测试函数
{
    return (*p)(a, b);
}
void main(void) { int a = 8; int b = 2; int temp; temp = test(sub, a, b); printf("%d ", temp);
temp = test(sub, a, b); printf("%d ", temp); } 执行结果就是 10 //8+2 16 //8*2

  

主要分析这个test函数,它的三个输入参数分别是

int(*p)(int, int)         int a          int b

后面两个好理解,就是跟常见的一样,是一个函数的两个输入参数。

第一个要怎么理解呢?

首先看到有*p,这是一个指针了,类比到int *a这种定义方式来理解。

Int temp; //定义一个int型的变量,名称是temp

Int  *a;//定义一个 名称为a的指针这个指针限定的范围是“int型变量”

 

那么 int(*p)(int, int) 它的意思是:定义一个名称为p的函数指针, p指向的函数要求有两个int输入参数,而且要求这个函数返回值是int型。

 

那么趁热打铁 void(*p)(int, int) 的意思就是:定义一个名称为p的函数指针, p指向的函数要求有两个int输入参数,而且要求这个函数返回值是void型(即没有)。

void(*p)(int) 的意思就是:定义一个名称为p的函数指针, p指向的函数要求有一个int输入参数,而且要求这个函数返回值是void型(即没有)。

char(*p)(int) 的意思就是:定义一个名称为p的函数指针, p指向的函数要求有一个int输入参数,而且要求这个函数返回值是char型。

 

那么现在有没有明白这个参数的定义呢?有篇资料中有这么一段

『函数指针的定义比较怪,为什么不是 void ()(int, int, float) *p_func 而是 void (*p_func)(int, int, float) 这种形式?』

这个问题我也不知道,也没必要纠结,花点时间理解下它与普通指针的区别,记住这就是它的定义形式。

 

到这里:总结一下,上面的这么多都是在解释一个“函数指针”。理解了函数指针,我们再看一下

int test(int(*p)(int,int), int a, int b)//测试函数

{

    return (*p)(a, b);

}

这个test函数是把 p 这个函数指针作为一个输入参数。

temp = test(sub, a, b);

这一句呢,就是把p指向sub这个函数, Sub要求有两个int输入参数。

Return (*p)(a,b); 就是 Return sub(a,b);

这里 *p这个参数起作用了,这个参数就是sub函数,就是一个算法。

第三步:疑问解惑

现在来回答第一步的三个问题了:

1、  怎么看出来就是一个回调函数呢?

像test函数那样,用了函数指针作为输入参数的函数就是回调函数,这种函数会调用另外一个函数作为输入参数。

 

2、  怎么定义一个回调函数呢?

回调函数定义跟上面的识别是反过来的,定义一个函数,它的输入参数中有函数指针,那么你就定义了一个回调函数。

 

3、  怎么使用回调函数呢?

使用回调函数的地方有很多,一般用于封装的程序给开发人员留出后期开发接口。

int test(int(*p)(int,int), int a, int b)

{

         Int temp;

           temp = (*p)(a, b);

         if(temp != 0)

                  return OK;

         else

                  return FAIL;

}

这样修改test函数,然后封装起来,test函数只做结果是否不为0的判断,后期开发人员可以把计算函数指针,和参与计算的两个参数输入给test,test就能返回OK或者FAIL。虽然这个test函数已经固定了,但后期开发人员依然可以随便更改计算函数。



以上是关于torch.gather函数的理解的主要内容,如果未能解决你的问题,请参考以下文章

pytorch - torch.gather 的倒数

torch.gather()之通俗易懂讲解

小白学习之pytorch框架-softmax回归(torch.gather()torch.argmax())

pytorch 笔记:gather 函数

pytorch-torch2:张量计算和连接

Pytorch的gather用法理解