如何在给定索引列表的情况下有效地更新 numpy ndarray
Posted
技术标签:
【中文标题】如何在给定索引列表的情况下有效地更新 numpy ndarray【英文标题】:How to efficiently update a numpy ndarray given a list of indices 【发布时间】:2022-01-12 21:38:12 【问题描述】:我有一个名为 new_arr 的 4 维数组。给定一个索引列表,我想根据我存储的旧数组 old_arr 更新 new_arr。我正在使用 for 循环来执行此操作,但效率低下。我的代码如下所示:
update_indices = [(2,33,1,8), (4,9,49,50), ...] #as an example
for index in update_indices:
i,j,k,l = index
new_arr[i][j][k][l] = old_arr[i][j][k][l]
这需要很长时间,因为 update_indices 很大。有没有一种方法可以一次更新所有条款或更有效地更新?
【问题讨论】:
你的数组是什么形状的? 它是 (57, 57, 57, 57) 唯一明显的就是直接使用index
元组。在简短的测试中,[index]
与 [index[0]][index[1]]
相比,我得到了约 2 倍的加速。
new_arr[index] = old_arr[index]
为了使上述 cmets 中提出的索引更快,您需要 update_indices
是一个 Numpy 数组,而不是一个纯 Python 元组列表。这意味着它应该直接生成为 Numpy 数组(因为转换也很昂贵)。
【参考方案1】:
只要做:
idx = np.array([(2,33,1,8), (4,9,49,50), ...])
new_arr[idx[:,0],idx[:,1],idx[:,2],idx[:,3]] = old_arr[idx[:,0],idx[:,1],idx[:,2],idx[:,3]]
不需要循环。
【讨论】:
你检查过这不会改变语义吗?根据我对 numpy 索引的理解,这应该不起作用。 你是对的,更新了我的答案 但我可以看到你的答案中已经包含了那个【参考方案2】:出于好奇,我对 cmets 中发布的各种改进进行了基准测试,发现使用平面索引是最快的。
我使用了以下设置:
rt numpy as np
n = 57
d = 4
k = int(1e6)
dt = np.double
new_arr = np.arange(n**d, dtype=dt).reshape(d * (n,))
new_arr2 = np.arange(n**d, dtype=dt).reshape(d * (n,))
old_arr = 2*np.arange(n**d, dtype=dt).reshape(d * (n,))
update_indices = list(tuple(np.random.randint(n, size=d)) for _ in range(int(k*1.1)))[:k]
其中update_indices
是 1e6 个唯一索引元组的列表。
使用问题中的原始技术
%%timeit
for index in update_indices:
i,j,k,l = index
new_arr[i][j][k][l] = old_arr[i][j][k][l]
接受1.47 s ± 19.3 ms
。
@defladamouse 建议的直接元组索引
%%timeit
for index in update_indices:
new_arr[index] = old_arr[index]
确实给了我们 2 倍的加速:778 ms ± 41.8 ms
如果没有给出update_indices
,但可以按照@Jérôme Richard 的建议构造为ndarray
update_indices_array = np.array(update_indices, dtype=np.uint32)
(转换本身需要1.34 s
)实现更快的路径是开放的。
为了通过多维位置列表索引 numpy 数组,我们不能直接使用 update_indices_array
作为索引,而是将其列打包成一个元组:
%%timeit
idx = tuple(update_indices_array.T)
new_arr2[idx] = old_arr[idx]
再给出大约 9 倍的加速:83.5 ms ± 1.45
如果我们不将内存偏移的计算留给ndarray.__getitem__
,
但是“手动”计算对应的平面索引,我们可以变得更快:
%%timeit
idx_weights = np.cumprod((1,) + new_arr2.shape[:0:-1])[::-1]
update_flat = update_indices_array @ idx_weights
new_arr2.ravel()[update_flat] = old_arr.ravel()[update_flat]
导致41.6 ms ± 1.04 ms
,与原始版本相比,又增加了 2 倍和累积加速因子 35。
idx_weights
只是数组维度的一个逆序累积乘积。
我假设这个 2 的加速来自这样一个事实,即内存偏移量/平面索引在 new_arr2[idx] = old_arr[idx]
中计算了两次,而在 update_flat = update_indices_array @ idx_weight
中只计算了一次。
【讨论】:
以上是关于如何在给定索引列表的情况下有效地更新 numpy ndarray的主要内容,如果未能解决你的问题,请参考以下文章
如何在给定键和值数组的情况下有效地在 Eloquent 中进行大规模更新
如何在没有科学记数法和给定精度的情况下漂亮地打印 numpy.array?
使用 NumPy 从矩阵中获取最小/最大 n 值和索引的有效方法