如果 cumsum 大于值,则重新启动 cumsum 并获取索引

Posted

技术标签:

【中文标题】如果 cumsum 大于值,则重新启动 cumsum 并获取索引【英文标题】:Restart cumsum and get index if cumsum more than value 【发布时间】:2019-11-16 03:59:57 【问题描述】:

假设我有一系列距离x=[1,2,1,3,3,2,1,5,1,1]

我想从 x 中获取 cumsum 达到 10 的索引,在本例中为 idx=[4,9]。

所以条件满足后cumsum重新开始。

我可以用循环来做,但是对于大型数组来说,循环很慢,我想知道我是否可以用vectorized 的方式来做。

【问题讨论】:

矢量化方法难以实现 嗯看起来像@WeNYoBen @yatu 我想我之前回答过那些类型的问题 :-) ...更好的解决方案可能使用 numba .. 如果你想在达到 10 时重置 cumsum,这样的方法可以工作:x.cumsum()%10。如果x 是一个数组,这真的很快。但这也可能不适合你,因为你想如何处理边缘情况还不清楚。比如,如果cumsum 是 11,它应该重置为 0 还是 1? 没有@brenlla 不能很好地工作,因为它没有考虑可能超过 10 的先前值的剩余部分,例如尝试使用 [1,2,1,3,6,2,1,5,1,1] 【参考方案1】:

这是一个带有 numba 和数组初始化的 -

from numba import njit

@njit
def cumsum_breach_numba2(x, target, result):
    total = 0
    iterID = 0
    for i,x_i in enumerate(x):
        total += x_i
        if total >= target:
            result[iterID] = i
            iterID += 1
            total = 0
    return iterID

def cumsum_breach_array_init(x, target):
    x = np.asarray(x)
    result = np.empty(len(x),dtype=np.uint64)
    idx = cumsum_breach_numba2(x, target, result)
    return result[:idx]

时间

包括 @piRSquared's solutions 并使用同一篇文章中的基准测试设置 -

In [58]: np.random.seed([3, 1415])
    ...: x = np.random.randint(100, size=1000000).tolist()

# @piRSquared soln1
In [59]: %timeit list(cumsum_breach(x, 10))
10 loops, best of 3: 73.2 ms per loop

# @piRSquared soln2
In [60]: %timeit cumsum_breach_numba(np.asarray(x), 10)
10 loops, best of 3: 69.2 ms per loop

# From this post
In [61]: %timeit cumsum_breach_array_init(x, 10)
10 loops, best of 3: 39.1 ms per loop

Numba:追加与数组初始化

为了进一步了解数组初始化如何提供帮助,这似乎是两个 numba 实现之间的巨大差异,让我们在数组数据上计时,因为数组数据创建本身就在运行时很重,而且它们都依赖于就可以了-

In [62]: x = np.array(x)

In [63]: %timeit cumsum_breach_numba(x, 10)# with appending
10 loops, best of 3: 31.5 ms per loop

In [64]: %timeit cumsum_breach_array_init(x, 10)
1000 loops, best of 3: 1.8 ms per loop

要强制输出有自己的内存空间,我们可以制作一个副本。不过不会有大的改变 -

In [65]: %timeit cumsum_breach_array_init(x, 10).copy()
100 loops, best of 3: 2.67 ms per loop

【讨论】:

这可能不会有太大的不同,但我认为完全公平的是,您应该返回一份 result[:idx] 的副本以避免泄露 result[idx:] 的内存。 很好的答案!并不惊讶?只是说 @piRSquared 肯定是受你的激励。认为这种基于数组初始化的方法可以用于 numba 解决方案,当事先不知道输出数组大小时,这些解决方案需要使用附加。而且,数组数据与 numba 配合得更好。所以,这两个是从这次问答中学到的。 是的,很棒的技术。我会用它【参考方案2】:

循环并非总是不好的(尤其是当您需要循环时)。此外,没有任何工具或算法可以比 O(n) 更快。所以让我们做一个好的循环。

生成器函数

def cumsum_breach(x, target):
    total = 0
    for i, y in enumerate(x):
        total += y
        if total >= target:
            yield i
            total = 0

list(cumsum_breach(x, 10))

[4, 9]

使用 Numba 及时编译

Numba 是需要安装的第三方库。 Numba 对于支持哪些功能可能会很挑剔。但这行得通。 此外,正如 Divakar 所指出的,Numba 在数组方面表现更好

from numba import njit

@njit
def cumsum_breach_numba(x, target):
    total = 0
    result = []
    for i, y in enumerate(x):
        total += y
        if total >= target:
            result.append(i)
            total = 0

    return result

cumsum_breach_numba(x, 10)

测试两者

因为我喜欢它¯\_(ツ)_/¯

设置

np.random.seed([3, 1415])
x0 = np.random.randint(100, size=1_000_000)
x1 = x0.tolist()

准确度

i0 = cumsum_breach_numba(x0, 200_000)
i1 = list(cumsum_breach(x1, 200_000))

assert i0 == i1

时间

%timeit cumsum_breach_numba(x0, 200_000)
%timeit list(cumsum_breach(x1, 200_000))

582 µs ± 40.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
64.3 ms ± 5.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Numba 大约快 100 倍。

为了更真实的苹果对苹果测试,我将列表转换为 Numpy 数组

%timeit cumsum_breach_numba(np.array(x1), 200_000)
%timeit list(cumsum_breach(x1, 200_000))

43.1 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
62.8 ms ± 327 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

这让他们差不多。

【讨论】:

这应该足够快了。在非负数组的情况下,我们可以在total > 10时使用break来加快速度。 可能会加速 numba :-) BTW 快乐 7-4 @piRSquared 太棒了,先生 :-) 您好,先生,请您测试一下速度,想知道 frompyfunc 在时间上与 numba 相比有多差。 @piRSquared 转换为数组,然后将其提供给 numba。看看魔法!【参考方案3】:

一个有趣的方法

sumlm = np.frompyfunc(lambda a,b:a+b if a < 10 else b,2,1)
newx=sumlm.accumulate(x, dtype=np.object)
newx
array([1, 3, 4, 7, 10, 2, 3, 8, 9, 10], dtype=object)
np.nonzero(newx==10)

(array([4, 9]),)

【讨论】:

检查np.random.seed([3, 1415]); x = np.random.randint(100, size=1_000_000).tolist() 在其他测试中,frompyfunc 往往比更显式的 python 循环快 2 倍。 我没有意识到frompyfunc 使用accumulate 之类的方法生成了ufunc @hpaulj 啊早上才发现 :-)

以上是关于如果 cumsum 大于值,则重新启动 cumsum 并获取索引的主要内容,如果未能解决你的问题,请参考以下文章

批处理文件:如果 xxxxxx 大于 xxKB,则启动文件

如果我在插入后重新启动识别器列,则脚本文件中的错误

如果在之前进行预处理,则数据处理时间太长

pytorch torch.cumsum(input, dim, out=None)函数(沿轴逐级累加)

如果值大于0,则过滤不同[重复]

如果列大于某个值,则更新列