如何找到 Numpy 数组的 M 个元素的 N 个最大乘积子数组?
Posted
技术标签:
【中文标题】如何找到 Numpy 数组的 M 个元素的 N 个最大乘积子数组?【英文标题】:How to find N maximum product subarrays of M elements of a Numpy array? 【发布时间】:2020-09-06 22:32:25 【问题描述】:我有一个 Numpy 数组,我需要找到 M 个元素的 N 个最大乘积子数组。例如,我有数组p = [0.1, 0.2, 0.8, 0.5, 0.7, 0.9, 0.3, 0.5]
,我想找到 3 个元素的 5 个最高乘积子数组。有没有“快速”的方法来做到这一点?
【问题讨论】:
那不是一个numpy数组 发布的解决方案是否对您有用? 【参考方案1】:这是另一种快速的方法:
import numpy as np
p = [0.1, 0.2, 0.8, 0.5, 0.7, 0.9, 0.3, 0.5]
n = 5
m = 3
# Cumulative product (starting with 1)
pc = np.cumprod(np.r_[1, p])
# Cumulative product of each window
w = pc[m:] / pc[:-m]
# Indices of the first element of top N windows
idx = np.argpartition(w, n)[-n:]
print(idx)
# [1 2 5 4 3]
【讨论】:
这个主意不错。只是如果数组大小合适,并且其中包含分数,它最终会在后端返回零。 @Divakar 是的,这很好,如果数组足够大,那么精度可能会受到影响,如果不是,那么性能可能无论如何都不是问题。【参考方案2】:方法#1
我们可以创建滑动窗口,然后执行prod
缩减,最后执行np.argpartition
以得到其中最上面的N
-
from skimage.util.shape import view_as_windows
def topN_windowed_prod(a, W, N):
w = view_as_windows(a,W)
return w[w.prod(1).argpartition(-N)[-N:]]
示例运行 -
In [2]: p = np.array([0.1, 0.2, 0.8, 0.5, 0.7, 0.9, 0.3, 0.5])
In [3]: topN_windowed_prod(p, W=3, N=2)
Out[3]:
array([[0.8, 0.5, 0.7],
[0.5, 0.7, 0.9]])
请注意,np.argpartition
不维护订单。因此,如果我们需要按 prod
值的降序排列的顶部 N
,请使用 range(N)
。 More info.
方法 #2
对于较小的窗口长度,我们可以简单地切片并获得我们想要的结果,就像这样 -
def topN_windowed_prod_with_slicing(a, W, N):
w = view_as_windows(a,W)
L = len(a)-W+1
acc = a[:L].copy()
for i in range(1,W):
acc *= a[i:i+L]
idx = acc.argpartition(-N)[-N:]
return w[idx]
【讨论】:
以上是关于如何找到 Numpy 数组的 M 个元素的 N 个最大乘积子数组?的主要内容,如果未能解决你的问题,请参考以下文章
我们能否找到元素是不是存在于数组 1,2,...,n 中,其中元素 m 个不同的元素在 Θ(m) 中? [关闭]