PyTorch中scatter和gather的用法

Posted liuzhan709

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch中scatter和gather的用法相关的知识,希望对你有一定的参考价值。

PyTorch中scatter和gather的用法

闲扯

许久没有更新博客了,2019年总体上看是荒废的,没有做出什么东西,明年春天就要开始准备实习了,虽然不找算法岗的工作,但是还是准备在2019年的最后一个半月认真整理一下自己学习的机器学习和深度学习的知识。

scatter的用法

scatter中文翻译为散射,首先看一个例子来直观感受一下这个API的功能,使用pytorch官网提供的例子。

import torch 
import torch.nn as nn
x = torch.rand(2,5)
x
tensor([[0.2656, 0.5364, 0.8568, 0.5845, 0.2289],
        [0.0010, 0.8101, 0.5491, 0.6514, 0.7295]])
y = torch.zeros(3,5)
index = torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
index
tensor([[0, 1, 2, 0, 0],
        [2, 0, 0, 1, 2]])
y.scatter_(dim=0,index=index,src=x)
y
tensor([[0.2656, 0.8101, 0.5491, 0.5845, 0.2289],
        [0.0000, 0.5364, 0.0000, 0.6514, 0.0000],
        [0.0010, 0.0000, 0.8568, 0.0000, 0.7295]])

首先我们可以看到,x的所有值都在y中出现了,且被索引的轴为dim=0,任意一个来自x中的元素,将按照以下公式完成映射。
y[index[i,j],j] = x[i,j],对于x[0,1] = 0.5364,index[0,1] = 1指出这个值将出现在y的第dim=0维,下标为1的位置,因此,y[index[0,1],1] = y[1,1] = x[0,1] = 0.5364.

到这里我们已经对scatter,即散射这个函数有了直观的认识,可用于将一个矩阵映射到一个矩阵,dim指明要映射的轴,index指明要映射的轴的下标,因此对于3D张量,若调用y.scatter_(dim,index,src),那么有:

y[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
y[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
y[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

最后看一个官方文档的关于scatter的英文说明:

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

意思和直观感受几乎相同,函数可将src映射到目标张量self,在dim维度上,由索引index给出下标,在非dim维度上,直接使用src值所在位置的下标。

self, index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.

显然self,index,src的ndim应该相同了,否则下标越界了,从公式上看index.size(d) > src.size(d),index.size(d) > self.size(d)没什么问题,index数组可以比src更大,猜测这里是工程上的考虑,因为超出src大小的index数组在这里是没用的,闲置的空间不会被访问。

Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive, and all values in a row along the specified dimension dim must be unique.

index所有的值需要在[0,self.size(dim) - 1]区间内,这是必须满足的,否则越界了。第二句说沿着dim维的index的所有值需要是唯一的,我测试的结果发现可以重复,看下面的代码:

x = torch.rand(2,5)
x
tensor([[0.6542, 0.6071, 0.7546, 0.4880, 0.1077],
        [0.9535, 0.0992, 0.0594, 0.0641, 0.7563]])
index = torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
y = torch.zeros(3,5)
y.scatter_(dim=0,index=index,src=x)
tensor([[0.6542, 0.0992, 0.0594, 0.4880, 0.1077],
        [0.0000, 0.6071, 0.0000, 0.0641, 0.0000],
        [0.9535, 0.0000, 0.7546, 0.0000, 0.7563]])
index = torch.tensor([[0,1,2,0,0],[0,1,2,0,0]])
y = torch.zeros(3,5)
y.scatter_(dim=0,index=index,src=x)
tensor([[0.9535, 0.0000, 0.0000, 0.0641, 0.7563],
        [0.0000, 0.0992, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0594, 0.0000, 0.0000]])

可以看到沿着dim=0轴上重复了5次,分别是(0,0),(1,1),(2,2),(0,0),(0,0),代码无报错和警告,只是覆盖掉了原来的值,可能是文档没有更新,但是API更新了。

params:

  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter, can be either empty or the same size of src. When empty, the operation returns identity
  • src (Tensor) – the source element(s) to scatter, incase value is not specified
  • value (float) – the source element(s) to scatter, incase src is not specified

值得注意的是value参数,当没有指明src时,可以指定一个浮点value变量,利用这一点我们实现一个scatter版本的onehot函数。

x = torch.tensor([[1,1,1,1,1]],dtype=torch.float32)
index = torch.tensor([[0,1,2,3,4]],dtype=torch.int64)
y = torch.zeros(5,5,dtype=torch.float32)
x
tensor([[1., 1., 1., 1., 1.]])
y.scatter_(0,index,x)
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])
y = torch.zeros(5,5,dtype=torch.float32)
y.scatter_(0,index,1)
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])

可以看到指定value=1,和src=[[1,1,1,1,1]]等价。到这里scatter就结束了。

gather的用法

gather是scatter的逆过程,从一个张量收集数据,到另一个张量,看一个例子有个直观感受。

x = torch.tensor([[1,2],[3,4]])
torch.gather(input=x,dim=1,index=torch.tensor([[0,0],[1,0]]))
tensor([[1, 1],
        [4, 3]])

可以猜测到收集过程,根据index和dim将x中的数据挑选出来,放置到y中,满足下面的公式:
y[i,j] = x[i,index[i,j]],因此有y[0,0] = x[0,index[0,0]] = x[0,0] = 1, y[1,0] = x[1,index[1,0]] = x[1,1] = 4,对于3D数据,满足以下公式:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

到这里gather的用法介绍就结束了,因为gather毕竟是scatter的逆过程,理解了scatter,gather并不需要太多说明。

小结

  1. scatter可以将一个张量映射到另一个张量,其中一个应用是onehot函数.
  2. gather和scatter是两个互逆的过程,gather可用于压缩稀疏张量,收集稀疏张量中非0的元素。
  3. 别再荒废时光了,做不出成果也不能全怪自己的。

以上是关于PyTorch中scatter和gather的用法的主要内容,如果未能解决你的问题,请参考以下文章

Java NIO Scatter/Gather

Java NIO Scatter/Gather

Java NIO系列教程 Scatter/Gather

什么是Scatter/Gather?

Java NIO系列教程 Scatter/Gather

如何使用 MPI_Scatter 和 MPI_Gather 计算多个进程的平均值?