为啥循环在这里比索引好?

Posted

技术标签:

【中文标题】为啥循环在这里比索引好?【英文标题】:Why Does Looping Beat Indexing Here?为什么循环在这里比索引好? 【发布时间】:2011-03-31 22:05:35 【问题描述】:

几年前,有人posted 在Active State Recipes 上进行比较,三个python/NumPy 函数;它们中的每一个都接受相同的参数并返回相同的结果,即距离矩阵

其中两个取自已发布的资源;它们都是——或者在我看来它们都是——惯用的 numpy 代码。创建距离矩阵所需的重复计算由 numpy 优雅的索引语法驱动。这是其中之一:

from numpy.matlib import repmat, repeat

def calcDistanceMatrixFastEuclidean(points):
  numPoints = len(points)
  distMat = sqrt(sum((repmat(points, numPoints, 1) - 
             repeat(points, numPoints, axis=0))**2, axis=1))
  return distMat.reshape((numPoints,numPoints))

第三个使用单个循环创建了距离矩阵(考虑到只有 1,000 个 2D 点的距离矩阵有 100 万个条目,这显然是很多循环)。乍一看,这个函数在我看来就像我在学习 NumPy 时编写的代码,我会先编写 Python 代码然后逐行翻译来编写 NumPy 代码。

在 Active State 发布几个月后,比较这三者的性能测试结果在 NumPy 邮件列表上的 thread 中发布和讨论。

带有循环的函数实际上明显优于其他两个:

from numpy import mat, zeros, newaxis

def calcDistanceMatrixFastEuclidean2(nDimPoints):
  nDimPoints = array(nDimPoints)
  n,m = nDimPoints.shape
  delta = zeros((n,n),'d')
  for d in xrange(m):
    data = nDimPoints[:,d]
    delta += (data - data[:,newaxis])**2
  return sqrt(delta)

线程中的一位参与者 (Keir Mierle) 提出了这可能是真的原因:

我怀疑这会更快的原因是 它具有更好的局部性,完全完成了一个计算 在进入下一个之前的相对较小的工作集。一个衬垫 必须反复将可能很大的 MxN 数组拉入处理器。

根据发帖者本人的说法,他的言论只是怀疑,似乎没有进一步讨论。

关于如何解释这些结果的任何其他想法?

特别是,是否有一个有用的规则——关于何时循环和何时索引——可以从这个例子中提取出来作为编写 numpy 代码的指导?

对于那些不熟悉 NumPy 或者没有看过代码的人来说,这种比较并不是基于边缘情况——如果是这样的话,我肯定不会那么感兴趣。相反,这种比较涉及一个在矩阵计算中执行常见任务的函数(即,在给定两个前项的情况下创建一个结果数组);此外,每个函数又由最常见的 numpy 内置函数组成。

【问题讨论】:

【参考方案1】:

dis 好玩:

dis.dis(calcDistanceMatrixFastEuclidean)

  2           0 LOAD_GLOBAL              0 (len)
              3 LOAD_FAST                0 (points)
              6 CALL_FUNCTION            1
              9 STORE_FAST               1 (numPoints)

  3          12 LOAD_GLOBAL              1 (sqrt)
             15 LOAD_GLOBAL              2 (sum)
             18 LOAD_GLOBAL              3 (repmat)
             21 LOAD_FAST                0 (points)
             24 LOAD_FAST                1 (numPoints)
             27 LOAD_CONST               1 (1)
             30 CALL_FUNCTION            3

  4          33 LOAD_GLOBAL              4 (repeat)
             36 LOAD_FAST                0 (points)
             39 LOAD_FAST                1 (numPoints)
             42 LOAD_CONST               2 ('axis')
             45 LOAD_CONST               3 (0)
             48 CALL_FUNCTION          258
             51 BINARY_SUBTRACT
             52 LOAD_CONST               4 (2)
             55 BINARY_POWER
             56 LOAD_CONST               2 ('axis')
             59 LOAD_CONST               1 (1)
             62 CALL_FUNCTION          257
             65 CALL_FUNCTION            1
             68 STORE_FAST               2 (distMat)

  5          71 LOAD_FAST                2 (distMat)
             74 LOAD_ATTR                5 (reshape)
             77 LOAD_FAST                1 (numPoints)
             80 LOAD_FAST                1 (numPoints)
             83 BUILD_TUPLE              2
             86 CALL_FUNCTION            1
             89 RETURN_VALUE

dis.dis(calcDistanceMatrixFastEuclidean2)

  2           0 LOAD_GLOBAL              0 (array)
              3 LOAD_FAST                0 (nDimPoints)
              6 CALL_FUNCTION            1
              9 STORE_FAST               0 (nDimPoints)

  3          12 LOAD_FAST                0 (nDimPoints)
             15 LOAD_ATTR                1 (shape)
             18 UNPACK_SEQUENCE          2
             21 STORE_FAST               1 (n)
             24 STORE_FAST               2 (m)

  4          27 LOAD_GLOBAL              2 (zeros)
             30 LOAD_FAST                1 (n)
             33 LOAD_FAST                1 (n)
             36 BUILD_TUPLE              2
             39 LOAD_CONST               1 ('d')
             42 CALL_FUNCTION            2
             45 STORE_FAST               3 (delta)

  5          48 SETUP_LOOP              76 (to 127)
             51 LOAD_GLOBAL              3 (xrange)
             54 LOAD_FAST                2 (m)
             57 CALL_FUNCTION            1
             60 GET_ITER
        >>   61 FOR_ITER                62 (to 126)
             64 STORE_FAST               4 (d)

  6          67 LOAD_FAST                0 (nDimPoints)
             70 LOAD_CONST               0 (None)
             73 LOAD_CONST               0 (None)
             76 BUILD_SLICE              2
             79 LOAD_FAST                4 (d)
             82 BUILD_TUPLE              2
             85 BINARY_SUBSCR
             86 STORE_FAST               5 (data)

  7          89 LOAD_FAST                3 (delta)
             92 LOAD_FAST                5 (data)
             95 LOAD_FAST                5 (data)
             98 LOAD_CONST               0 (None)
            101 LOAD_CONST               0 (None)
            104 BUILD_SLICE              2
            107 LOAD_GLOBAL              4 (newaxis)
            110 BUILD_TUPLE              2
            113 BINARY_SUBSCR
            114 BINARY_SUBTRACT
            115 LOAD_CONST               2 (2)
            118 BINARY_POWER
            119 INPLACE_ADD
            120 STORE_FAST               3 (delta)
            123 JUMP_ABSOLUTE           61
        >>  126 POP_BLOCK

  8     >>  127 LOAD_GLOBAL              5 (sqrt)
            130 LOAD_FAST                3 (delta)
            133 CALL_FUNCTION            1
            136 RETURN_VALUE

我不是dis 方面的专家,但您似乎必须更多地查看第一个调用的函数才能知道它们为什么需要一段时间。还有一个带有 Python 的性能分析器工具,cProfile

【讨论】:

如果你使用cProfile,我建议使用RunSnakeRun查看结果。 我注意到 Python 优化的技巧似乎通常是让 Python 解释器执行尽可能少的 Python 指令。【参考方案2】:

TL; DR 上面的第二个代码只是循环点的维数(通过 for 循环 3 次 3D 点),所以循环并不多。上面第二个代码中真正的加速是它更好地利用了 Numpy 的强大功能,以避免在查找点之间的差异时创建一些额外的矩阵。这减少了内存使用和计算量。

更长的解释 我认为calcDistanceMatrixFastEuclidean2 函数可能会用它的循环来欺骗你。它只是循环点的维数。对于 1D 点,循环只执行一次,对于 2D,执行两次,对于 3D,执行三次。这真的没有太多循环。

让我们稍微分析一下代码,看看为什么一个比另一个快。 calcDistanceMatrixFastEuclidean 我会打电话给fast1calcDistanceMatrixFastEuclidean2 将是fast2

fast1 基于 Matlab 的处理方式,repmap 函数证明了这一点。在这种情况下,repmap 函数会创建一个数组,它只是一遍又一遍地重复的原始数据。但是,如果您查看该函数的代码,则效率非常低。它使用许多 Numpy 函数(3 个reshapes 和 2 个repeats)来执行此操作。 repeat 函数还用于创建一个包含原始数据的数组,其中每个数据项重复多次。如果我们的输入数据是[1,2,3],那么我们从[1,1,1,2,2,2,3,3,3] 中减去[1,2,3,1,2,3,1,2,3]。 Numpy 不得不在运行 Numpy 的 C 代码之间创建许多额外的矩阵,这本来是可以避免的。

fast2 使用了更多 Numpy 的繁重工作,而不会在 Numpy 调用之间创建尽可能多的矩阵。 fast2 循环遍历点的每个维度,进行减法并保持每个维度之间平方差的总和。只有在最后才算平方根。到目前为止,这听起来可能不如fast1 高效,但fast2 通过使用Numpy 的索引避免了repmat 的事情。为简单起见,让我们看一下 1D 案例。 fast2 制作一维数据数组,然后从二维 (N x 1) 数据数组中减去它。这会在每个点和所有其他点之间创建差异矩阵,而无需使用 repmatrepeat,从而绕过创建大量额外数组。这就是我认为真正的速度差异所在。 fast1 在矩阵之间创建了很多额外的内容(并且它们的计算成本很高)来找到点之间的差异,而 fast2 更好地利用 Numpy 的力量来避免这些。

顺便说一下,这里是fast2的稍微快一点的版本:

def calcDistanceMatrixFastEuclidean3(nDimPoints):
  nDimPoints = array(nDimPoints)
  n,m = nDimPoints.shape
  data = nDimPoints[:,0]
  delta = (data - data[:,newaxis])**2
  for d in xrange(1,m):
    data = nDimPoints[:,d]
    delta += (data - data[:,newaxis])**2
  return sqrt(delta)

不同之处在于我们不再将增量创建为零矩阵。

【讨论】:

以上是关于为啥循环在这里比索引好?的主要内容,如果未能解决你的问题,请参考以下文章

新手提问 python for循环问题 print (y) #这里为啥只输出一行?

为啥即使 loop='true' 我也没有嵌入音频循环?

为啥这个回调会产生无限循环

没有退出“while”循环,但为啥呢?

为啥不建议数组使用 JavaScript 的 For...In 循环? [复制]

在此代码中,对于 while 循环内的条件,错误显示“除以零”。为啥会这样?