pytorch 的 sum 和 softmax 方法 dim 参数的使用

Posted yhjoker

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch 的 sum 和 softmax 方法 dim 参数的使用相关的知识,希望对你有一定的参考价值。

  在阅读使用 pytorch 实现的代码时,笔者会遇到需要对某一维数据进行求和( sum )或 softmax 的操作。在 pytorch 中,上述两个方法均带有一个指定维度的 dim 参数,这里记录下 dim 参数的用法。

 

  torch.sum

  在 pytorch 中,提供 torch.sum 的两种形式,一种直接将待求和数据作为参数,则返回参数数据所有维度所有元素的和,另外一种除接收待求和数据作为参数外,还可加入 dim 参数,指定对待求和数据的某一维进行求和。

    out = torch.sum( a )                   #对 a 中所有元素求和
    out = torch.sum( a , dim = 1 )         #对 a 中第 1 维的元素求和

  上述第一种形式比较好理解,但第二种形式,加入 dim 参数后,比较令人疑惑的是到底哪些元素参与了求和?这里通过例子来进行说明。

  1)首先我们生成一个维度为 ( 3, 4, 5, 6 ) 的元素全为 1.0 的 tensor a。

    >>> import torch
    >>> a = torch.ones( 3, 4, 5, 6 )         #生成一个形状为 ( 3, 4, 5, 6 ) 的数据,数据类型默认为 torch.FloatTensor

  2)使用 sum 方法对上述生成的 tensor 进行求和操作。注意 tensor 的维度索引从 0 开始。

    >>> b = torch.sum( a )               #对 a 中所有元素求和, b = 360.0
    >>> c = torch.sum( a, dim = 0 )      #对 a 中 dim = 0 元素求和
    >>> c.shape                          # c 的 shape 为 torch.Size( [ 4, 5, 6 ] ),其中所有元素值为 3.0
    >>> d = torch.sum( a, dim = 3 )      #对 a 中 dim = 3 元素求和
    >>> d.shape                          # d 的 shape 为 torch.Size( [ 3, 4, 5 ] ),其中所有元素值为 6.0

  对上述结果进行解释,b 的结果很好理解,因为 tensor a 的维度为 ( 3, 4, 5, 6 ) 且其中所有元素的值为 1,则对其中所有元素求和的结果为 3 * 4 * 5 * 6 * 1.0 = 360.0 .

  对于 c 和 d 的结果,首先可以观察得到的是, 若在第 i 维进行求和,即 sum( a, dim = i ),则求和结果的每一个元素的值均为该维度的大小。如在 dim = 0 求和,在 dim = 0 上 a 的尺寸为 3,则求和结果 c 的每一个元素值为 3.0 .也就是说每个结果元素值均为是三个求和元素值( 1.0 )相加的结果,求和结果 c 的维度为 ( 4, 5, 6 ),说明待求和数据 a 分为 ( 4, 5, 6 ) 共 4 * 5 * 6 组的元素进行了求和运算。在 dim = 3 上的求和结果 d 现象与 c 保持一致。

  对于输入待求和数据所有数据元素均为 1 时,可以归纳出一个结论,对于维度为 ( s0, s1, s2, s3 ) 的 tensor 的第 i 维进行求和,如第 0 维,则结果的维度为 ( s1, s2, s3 ),其维度为原输入维度去除求和维度。结果的每一个元素值即为 1 * s0 = s0,即为待求和维度的尺寸。

  下面以三维数据即维度为 ( 3, 4, 4 ) 的 tensor a 为例展示 sum 在某一维度的实际计算过程。

                      技术图片

  使用 dim = 0 参数计算时,产生的结果维度为 ( 4, 4 ), 对于结果中的每一个位置 ( i, j ) ,由 3 个元素进行计算,实际计算的是 a[ 0 ][ i ][ j ] + a[ 1 ][ i ][ j ] + a[ 2 ][ i ][ j ],当上述三个元素的值均为 1.0 时,计算结果元素即为 3.0 。如上图左侧的图,a[ 0 ][ 3 ][ 3 ] + a[ 1 ][ 3 ][ 3 ] + a[ 2 ][ 3 ][ 3 ] 的结果即为输出 ( 3, 3 ) 位置上的值。上述位置索引 ( i, j ) 的数量由输入的待求和数据的其他维度的尺寸决定。 

  使用 dim = 2 参数计算时,产生的结果维度为 ( 3, 4 ),对于结果中的每一个位置( i, j ) ,由 4 个元素进行计算,实际计算的是 a[ i ][ j ][ 0 ] + a[ i ][ j ][ 1 ] + a[ i ][ j ][ 2 ] + a[ i ][ j ][ 3 ],当上述四个元素的值均为 1.0 时,计算结果元素即为 4.0 。如  a[ 0 ][ 0 ][ 0 ] + a[ 0 ][ 0 ][ 1 ] + a[ 0 ][ 0 ][ 2 ] + a[ 0 ][ 0 ][ 3 ] 即为输出 ( 0, 0 ) 位置上的值。

  对于维度为 ( s0, s1, s2, ... , si, ... , sn ) 的待求和向量,使用 dim = i 调用 sum 方法,则实际产生的结果维度为 ( s0, s1, s2, ... , si-1, si+1, ... , sn ),每个结果元素由 si 个元素元素求和获得。这 si 个元素坐标在其他维度索引保持一致,而在待求和维度索引由 0 至 si 变化。可以看到共有 ( s0, s1, s2, ... , si-1, si+1, ... , sn ) 组这样的求和元素( 索引的数量 ),即为结果的维度。

 

  torch.nn.softmax / torch.nn.functional.softmax

  softmax 是神经网路中常见的一种计算函数,其可将所有参与计算的对象值映射到 0 到 1 之间,并使得计算对象的和为 1. 在 pytorch 中的 softmax 方法在使用时也需要通过 dim 方法来指定具体进行 softmax 计算的维度。这里以 torch.nn.functional.softmax 为例进行说明。

  softmax 在 pytorch 官方文档中的描述如下:

  It is applied to all slices along dim, and will re-scale them so that the elements lie in the range [0, 1] and sum to 1.  

  可以明确的是, softmax 计算获得的数值在 0 - 1 之间,但是同样比较令人疑惑的是,all slices along the dim 具体指代的是那些数据。这里使用一个维度为 ( 2, 2, 2 ) 的 tensor a 作为示例。

    >>> import torch
    >>> import torch.nn.functional as f
    >>> a = torch.ones( 2, 2, 2 )
    >>> b = f.softmax( a, dim=0 )           #对 a 的第 0 维进行 softmax 计算

  与 sum 方法不同,softmax 方法计算获得的结果的维度与输入的待计算的数据的维度保持一致( sum 方法求和后进行指定求和的那一维不会出现在结果维度中 )。

  参与 softmax 计算的元素与 sum 方法很相似,对于 tensor a 在 dim = 0 进行 softmax,输出结果 b 实际上是 b[ 0 ][ i ][ j ] + b[ 1 ][ i ][ j ] 的值为 1.即其他维度索引保持一致,而在进行 softmax 维度索引由 0 至 si 变化,如 b[ 0 ][ 0 ][ 1 ] + a[ 1 ][ 0 ][ 1 ] 的值为1.对于 tensor a 在 dim = 2 进行 softmax,输出结果 b 实际上是 b[ i ][ j ][ 0 ] + a[ i ][ j ][ 1 ] 的值为 1.

    >>> c = b[ 0 ][ 0 ][ 1 ] + b[ 1 ][ 0 ][ 1 ]        #c 的值为 1

 

  参考

  pytorch - tensor-creation-ops

  pytorch - torch.tensor

  pytorch - torch.nn.functional

以上是关于pytorch 的 sum 和 softmax 方法 dim 参数的使用的主要内容,如果未能解决你的问题,请参考以下文章

Softmax回归的简洁实现(softmax-regression-pytorch)

Pytorch - 使用一种热编码和 softmax 的(分类)交叉熵损失

Pytorch softmax:使用啥维度?

《计算机视觉和图像处理简介 - 中英双语版》:基于PyTorch Softmax 进行 MNIST 手写数字分类Digit Classification with Softmax

《动手学深度学习》softmax回归(PyTorch版)

《动手学深度学习》softmax回归(PyTorch版)