快速 numpy 花式索引
Posted
技术标签:
【中文标题】快速 numpy 花式索引【英文标题】:Fast numpy fancy indexing 【发布时间】:2013-01-01 10:46:40 【问题描述】:我的 numpy 数组切片代码(通过花哨的索引)非常慢。目前是程序的瓶颈。
a.shape
(3218, 6)
ts = time.time(); a[rows][:, cols]; te = time.time(); print('%.8f' % (te-ts));
0.00200009
什么是正确的 numpy 调用来获取由矩阵 a 的行“rows”和列“col”的子集组成的数组? (其实我需要这个结果的转置)
【问题讨论】:
time.time
不是衡量时间的好方法。一般来说,最好使用timeit
。
1.你的程序到底在做什么? 2.使用适当的python分析器。我发现切片不太可能是您的瓶颈
Oren - 如果您使用 @mgilson
样式提及,它将向用户发送通知(每个评论一个)。
@mgilson:我记得在某些情况下(不过是 4 年前)遇到了问题,可能不再适用。手册说以下For all cases of index arrays, what is returned is a copy of the original data, not a view as one gets for slices
docs.scipy.org/doc/numpy/user/…
@Wolph Numpy 1.15 仍然如此:高级索引总是返回数据的副本(与返回视图的基本切片相反)
【参考方案1】:
让我试着总结一下 Jaime 和 TheodrosZelleke 的优秀答案,并混合一些 cmets。
-
Advanced (fancy) indexing 总是返回一个副本,而不是一个视图。
a[rows][:,cols]
意味着 两个 奇特的索引操作,因此创建并丢弃了一个中间副本 a[rows]
。方便且可读,但效率不高。此外请注意[:,cols]
通常会从 C-cont 生成 Fortran 连续副本。来源。
a[rows.reshape(-1,1),cols]
是一个高级索引表达式,它基于 rows.reshape(-1,1)
和 cols
是 broadcast 的预期结果的形状。
一个常见的经验是,扁平化数组中的索引可能比花式索引更有效,因此另一种方法是
indx = rows.reshape(-1,1)*a.shape[1] + cols
a.take(indx)
或
a.take(indx.flat).reshape(rows.size,cols.size)
效率取决于内存访问模式以及起始数组是 C 计数还是 Fortran 连续,因此需要进行实验。
仅在真正需要时才使用花哨的索引:basic slicing a[rstart:rstop:rstep, cstart:cstop:cstep]
返回一个视图(尽管不是连续的)并且应该更快!
【讨论】:
【参考方案2】:令我惊讶的是,这种计算第一个线性一维索引的冗长表达式比问题中提出的连续数组索引快 50%:
(a.ravel()[(
cols + (rows * a.shape[1]).reshape((-1,1))
).ravel()]).reshape(rows.size, cols.size)
更新: OP 更新了初始数组形状的描述。使用更新后的大小,加速现在高于 99%:
In [93]: a = np.random.randn(3218, 1415)
In [94]: rows = np.random.randint(a.shape[0], size=2000)
In [95]: cols = np.random.randint(a.shape[1], size=6)
In [96]: timeit a[rows][:, cols]
10 loops, best of 3: 186 ms per loop
In [97]: timeit (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
1000 loops, best of 3: 1.56 ms per loop
初始答案: 以下是实录:
In [79]: a = np.random.randn(3218, 6)
In [80]: a.shape
Out[80]: (3218, 6)
In [81]: rows = np.random.randint(a.shape[0], size=2000)
In [82]: cols = np.array([1,3,4,5])
时间法一:
In [83]: timeit a[rows][:, cols]
1000 loops, best of 3: 1.26 ms per loop
时间法2:
In [84]: timeit (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
1000 loops, best of 3: 568 us per loop
检查结果是否实际相同:
In [85]: result1 = a[rows][:, cols]
In [86]: result2 = (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
In [87]: np.sum(result1 - result2)
Out[87]: 0.0
【讨论】:
对于 OP 的新要求,它也比我提供的更标准的答案快两倍。不错! 并不是说这样的技巧不能加快速度(至少在特定情况下),但所有这一切都在很大程度上依赖于输入数组是 C 连续的事实。 不足为奇:看到这个answer到一个相关的问题。【参考方案3】:如果您使用精美的索引和广播进行切片,您可以获得一些速度:
from __future__ import division
import numpy as np
def slice_1(a, rs, cs) :
return a[rs][:, cs]
def slice_2(a, rs, cs) :
return a[rs[:, None], cs]
>>> rows, cols = 3218, 6
>>> rs = np.unique(np.random.randint(0, rows, size=(rows//2,)))
>>> cs = np.unique(np.random.randint(0, cols, size=(cols//2,)))
>>> a = np.random.rand(rows, cols)
>>> import timeit
>>> print timeit.timeit('slice_1(a, rs, cs)',
'from __main__ import slice_1, a, rs, cs',
number=1000)
0.24083110865
>>> print timeit.timeit('slice_2(a, rs, cs)',
'from __main__ import slice_2, a, rs, cs',
number=1000)
0.206566124519
如果您从百分比的角度来考虑,将速度提高 15% 总是好的,但在我的系统中,对于您的阵列大小,执行切片所需的时间减少了 40 us,很难相信需要 240 us 的操作将成为您的瓶颈。
【讨论】:
原来我有一个 3218x1415 数组,而不是 3218x6。我只提取几列和很多行。上面的代码显示 slice_1 的调用时间是 0.08 秒, slice_2 的调用时间是 0.0004 秒。也许这就是我需要的!【参考方案4】:使用np.ix_
,您可以达到与 ravel/reshape 相似的速度,但代码更清晰:
a = np.random.randn(3218, 1415)
rows = np.random.randint(a.shape[0], size=2000)
cols = np.random.randint(a.shape[1], size=6)
a = np.random.randn(3218, 1415)
rows = np.random.randint(a.shape[0], size=2000)
cols = np.random.randint(a.shape[1], size=6)
%timeit (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
#101 µs ± 2.36 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit ix_ = np.ix_(rows, cols); a[ix_]
#135 µs ± 7.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
ix_ = np.ix_(rows, cols)
result1 = a[ix_]
result2 = (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
np.sum(result1 - result2)
0.0
【讨论】:
以上是关于快速 numpy 花式索引的主要内容,如果未能解决你的问题,请参考以下文章