用另一个数组切片 numpy 数组
Posted
技术标签:
【中文标题】用另一个数组切片 numpy 数组【英文标题】:Slicing numpy array with another array 【发布时间】:2012-09-17 09:10:15 【问题描述】:我有一个大的一维整数数组,我需要将切片去掉。这很简单,我会做a[start:end]
。问题是我需要更多这些切片。如果 start 和 end 是数组,a[start:end]
不起作用。可以为此使用 for 循环,但我需要它尽可能快(这是一个瓶颈),因此欢迎使用本机 numpy 解决方案。
为了进一步说明,我有这个:
a = numpy.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], numpy.int16)
start = numpy.array([1, 5, 7], numpy.int16)
end = numpy.array([2, 10, 9], numpy.int16)
并且需要以某种方式使它变成这样:
[[1], [5, 6, 7, 8, 9], [7, 8]]
【问题讨论】:
我很难理解start
和 end
与此有什么关系。顺便说一句,我认为你不能完全在 numpy 中做到这一点,因为 numpy 数组需要是矩形的。
你可能会尝试将起始值作为列表中的元组。
由于这里似乎没有规范的 numpy 解决方案,如果您需要更多想法,您可能想要添加您之后实际使用它执行的操作,并且如果切片具有一些特殊属性。
【参考方案1】:
这可以(几乎?)在纯 numpy
中使用掩码数组和步幅技巧来完成。首先,我们创建我们的面具:
>>> indices = numpy.arange(a.size)
>>> mask = ~((indices >= start[:,None]) & (indices < end[:,None]))
或者更简单地说:
>>> mask = (indices < start[:,None]) | (indices >= end[:,None])
对于那些以>=
为起始值和<
为结束值的索引,掩码为False
(即未掩码的值)。 (使用None
(又名numpy.newaxis
)进行切片添加了一个新维度,可以进行广播。)现在我们的掩码如下所示:
>>> mask
array([[ True, False, True, True, True, True, True, True, True,
True, True, True],
[ True, True, True, True, True, False, False, False, False,
False, True, True],
[ True, True, True, True, True, True, True, False, False,
True, True, True]], dtype=bool)
现在我们必须使用stride_tricks
拉伸数组以适应掩码:
>>> as_strided = numpy.lib.stride_tricks.as_strided
>>> strided = as_strided(a, mask.shape, (0, a.strides[0]))
>>> strided
array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]], dtype=int16)
这看起来像一个 3x12 数组,但每一行都指向同一个内存。现在我们可以将它们组合成一个掩码数组:
>>> numpy.ma.array(strided, mask=mask)
masked_array(data =
[[-- 1 -- -- -- -- -- -- -- -- -- --]
[-- -- -- -- -- 5 6 7 8 9 -- --]
[-- -- -- -- -- -- -- 7 8 -- -- --]],
mask =
[[ True False True True True True True True True True True True]
[ True True True True True False False False False False True True]
[ True True True True True True True False False True True True]],
fill_value = 999999)
这与您要求的不太一样,但它的行为应该相似。
【讨论】:
好主意,想知道这种方法是否适用于他的用例(在较新的 numpy 版本上)。当前的缺少where
关键字到ufunc
s(1.7 也没有它来减少)。这意味着你的步幅技巧数组将被复制到完整版本中,几乎你在它上面做的任何事情......
嗯,ufunc
中缺少 where
与手头的问题无关,np.ma
通常会避免复制...这不是真的使用np.ma
(本身很酷的想法)的问题困扰着我,它可能不会击败使用循环或列表理解构建幻灯片(仅仅因为数组大小加倍)......不过,它很有趣, +1
@PierreGM,是的,我只是想到了那里的缩减功能,但在某些时候可能需要这些功能......【参考方案2】:
没有 numpy 方法可以做到这一点。请注意,由于它是不规则的,因此无论如何它只会是数组/切片的列表。但是我想补充一点,对于所有(二进制)ufuncs
,它们几乎是 numpy 中的所有函数(或者它们至少基于它们),有 reduceat
方法,它可以帮助您避免实际创建一个切片列表,因此,如果切片很小,也可以加快计算速度:
In [1]: a = np.arange(10)
In [2]: np.add.reduceat(a, [0,4,7]) # add up 0:4, 4:7 and 7:end
Out[2]: array([ 6, 15, 24])
In [3]: np.maximum.reduceat(a, [0,4,7]) # maximum of each of those slices
Out[3]: array([3, 6, 9])
In [4]: w = np.asarray([0,4,7,10]) # 10 for the total length
In [5]: np.add.reduceat(a, w[:-1]).astype(float)/np.diff(w) # equivalent to mean
Out[5]: array([ 1.5, 5. , 8. ])
编辑:由于您的切片重叠,我将补充说这也可以:
# I assume that start is sorted for performance reasons.
reductions = np.column_stack((start, end)).ravel()
sums = np.add.reduceat(a, reductions)[::2]
[::2]
通常在这里应该没什么大不了的,因为没有为重叠切片做真正的额外工作。
stop==len(a)
的切片也存在一个问题。必须避免这种情况。如果你只有一个切片,你可以做reductions = reductions[:-1]
(如果它是最后一个),否则你只需要在a
上附加一个值来欺骗reduceat
:
a = np.concatenate((a, [0]))
在末尾添加一个值并不重要,因为无论如何您都在处理切片。
【讨论】:
【参考方案3】:这不是一个“纯粹”的 numpy 解决方案(尽管正如 @mgilson 的评论所指出的,很难看出不规则输出如何成为一个 numpy 数组),但是:
a = numpy.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], numpy.int16)
start = numpy.array([1, 5, 7], numpy.int16)
end = numpy.array([2, 10, 9], numpy.int16)
map(lambda range: a[range[0]:range[1]],zip(start,end))
让你:
[array([1], dtype=int16), array([5, 6, 7, 8, 9], dtype=int16), array([7, 8], dtype=int16)]
根据需要。
【讨论】:
【参考方案4】:如果你想要它在一行中,它会是:
x=[list(a[s:e]) for (s,e) in zip(start,end)]
【讨论】:
【参考方案5】:类似 timday 的解决方案。类似的速度:
a = np.random.randint(0,20,1e6)
start = np.random.randint(0,20,1e4)
end = np.random.randint(0,20,1e4)
def my_fun(arr,start,end):
return arr[start:end]
%timeit [my_fun(a,i[0],i[1]) for i in zip(start,end)]
%timeit map(lambda range: a[range[0]:range[1]],zip(start,end))
100 loops, best of 3: 7.06 ms per loop
100 loops, best of 3: 6.87 ms per loop
【讨论】:
以上是关于用另一个数组切片 numpy 数组的主要内容,如果未能解决你的问题,请参考以下文章