如果 cumsum 大于值,则重新启动 cumsum 并获取索引
Posted
技术标签:
【中文标题】如果 cumsum 大于值,则重新启动 cumsum 并获取索引【英文标题】:Restart cumsum and get index if cumsum more than value 【发布时间】:2019-11-16 03:59:57 【问题描述】:假设我有一系列距离x=[1,2,1,3,3,2,1,5,1,1]
。
我想从 x 中获取 cumsum 达到 10 的索引,在本例中为 idx=[4,9]。
所以条件满足后cumsum重新开始。
我可以用循环来做,但是对于大型数组来说,循环很慢,我想知道我是否可以用vectorized
的方式来做。
【问题讨论】:
矢量化方法难以实现 嗯看起来像@WeNYoBen @yatu 我想我之前回答过那些类型的问题 :-) ...更好的解决方案可能使用 numba .. 如果你想在达到 10 时重置cumsum
,这样的方法可以工作:x.cumsum()%10
。如果x
是一个数组,这真的很快。但这也可能不适合你,因为你想如何处理边缘情况还不清楚。比如,如果cumsum
是 11,它应该重置为 0 还是 1?
没有@brenlla 不能很好地工作,因为它没有考虑可能超过 10 的先前值的剩余部分,例如尝试使用 [1,2,1,3,6,2,1,5,1,1]
【参考方案1】:
这是一个带有 numba 和数组初始化的 -
from numba import njit
@njit
def cumsum_breach_numba2(x, target, result):
total = 0
iterID = 0
for i,x_i in enumerate(x):
total += x_i
if total >= target:
result[iterID] = i
iterID += 1
total = 0
return iterID
def cumsum_breach_array_init(x, target):
x = np.asarray(x)
result = np.empty(len(x),dtype=np.uint64)
idx = cumsum_breach_numba2(x, target, result)
return result[:idx]
时间
包括 @piRSquared's solutions
并使用同一篇文章中的基准测试设置 -
In [58]: np.random.seed([3, 1415])
...: x = np.random.randint(100, size=1000000).tolist()
# @piRSquared soln1
In [59]: %timeit list(cumsum_breach(x, 10))
10 loops, best of 3: 73.2 ms per loop
# @piRSquared soln2
In [60]: %timeit cumsum_breach_numba(np.asarray(x), 10)
10 loops, best of 3: 69.2 ms per loop
# From this post
In [61]: %timeit cumsum_breach_array_init(x, 10)
10 loops, best of 3: 39.1 ms per loop
Numba:追加与数组初始化
为了进一步了解数组初始化如何提供帮助,这似乎是两个 numba 实现之间的巨大差异,让我们在数组数据上计时,因为数组数据创建本身就在运行时很重,而且它们都依赖于就可以了-
In [62]: x = np.array(x)
In [63]: %timeit cumsum_breach_numba(x, 10)# with appending
10 loops, best of 3: 31.5 ms per loop
In [64]: %timeit cumsum_breach_array_init(x, 10)
1000 loops, best of 3: 1.8 ms per loop
要强制输出有自己的内存空间,我们可以制作一个副本。不过不会有大的改变 -
In [65]: %timeit cumsum_breach_array_init(x, 10).copy()
100 loops, best of 3: 2.67 ms per loop
【讨论】:
这可能不会有太大的不同,但我认为完全公平的是,您应该返回一份result[:idx]
的副本以避免泄露 result[idx:]
的内存。
很好的答案!并不惊讶?只是说
@piRSquared 肯定是受你的激励。认为这种基于数组初始化的方法可以用于 numba 解决方案,当事先不知道输出数组大小时,这些解决方案需要使用附加。而且,数组数据与 numba 配合得更好。所以,这两个是从这次问答中学到的。
是的,很棒的技术。我会用它【参考方案2】:
循环并非总是不好的(尤其是当您需要循环时)。此外,没有任何工具或算法可以比 O(n) 更快。所以让我们做一个好的循环。
生成器函数
def cumsum_breach(x, target):
total = 0
for i, y in enumerate(x):
total += y
if total >= target:
yield i
total = 0
list(cumsum_breach(x, 10))
[4, 9]
使用 Numba 及时编译
Numba 是需要安装的第三方库。 Numba 对于支持哪些功能可能会很挑剔。但这行得通。 此外,正如 Divakar 所指出的,Numba 在数组方面表现更好
from numba import njit
@njit
def cumsum_breach_numba(x, target):
total = 0
result = []
for i, y in enumerate(x):
total += y
if total >= target:
result.append(i)
total = 0
return result
cumsum_breach_numba(x, 10)
测试两者
因为我喜欢它¯\_(ツ)_/¯
设置
np.random.seed([3, 1415])
x0 = np.random.randint(100, size=1_000_000)
x1 = x0.tolist()
准确度
i0 = cumsum_breach_numba(x0, 200_000)
i1 = list(cumsum_breach(x1, 200_000))
assert i0 == i1
时间
%timeit cumsum_breach_numba(x0, 200_000)
%timeit list(cumsum_breach(x1, 200_000))
582 µs ± 40.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
64.3 ms ± 5.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Numba 大约快 100 倍。
为了更真实的苹果对苹果测试,我将列表转换为 Numpy 数组
%timeit cumsum_breach_numba(np.array(x1), 200_000)
%timeit list(cumsum_breach(x1, 200_000))
43.1 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
62.8 ms ± 327 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
这让他们差不多。
【讨论】:
这应该足够快了。在非负数组的情况下,我们可以在total > 10
时使用break来加快速度。
可能会加速 numba :-) BTW 快乐 7-4
@piRSquared 太棒了,先生 :-)
您好,先生,请您测试一下速度,想知道 frompyfunc 在时间上与 numba 相比有多差。
@piRSquared 转换为数组,然后将其提供给 numba。看看魔法!【参考方案3】:
一个有趣的方法
sumlm = np.frompyfunc(lambda a,b:a+b if a < 10 else b,2,1)
newx=sumlm.accumulate(x, dtype=np.object)
newx
array([1, 3, 4, 7, 10, 2, 3, 8, 9, 10], dtype=object)
np.nonzero(newx==10)
(array([4, 9]),)
【讨论】:
检查np.random.seed([3, 1415]); x = np.random.randint(100, size=1_000_000).tolist()
在其他测试中,frompyfunc
往往比更显式的 python 循环快 2 倍。
我没有意识到frompyfunc
使用accumulate
之类的方法生成了ufunc
。
@hpaulj 啊早上才发现 :-)以上是关于如果 cumsum 大于值,则重新启动 cumsum 并获取索引的主要内容,如果未能解决你的问题,请参考以下文章