删除 NumPy 数组中具有重复项的行
Posted
技术标签:
【中文标题】删除 NumPy 数组中具有重复项的行【英文标题】:Removing rows with duplicates in a NumPy array 【发布时间】:2011-11-18 07:14:02 【问题描述】:我有一个(N,3)
numpy 值数组:
>>> vals = numpy.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])
>>> vals
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 7],
[0, 4, 5],
[2, 2, 1],
[0, 0, 0],
[5, 4, 3]])
我想从数组中删除具有重复值的行。例如,上述数组的结果应该是:
>>> duplicates_removed
array([[1, 2, 3],
[4, 5, 6],
[0, 4, 5],
[5, 4, 3]])
我不确定如何在不循环的情况下使用 numpy 有效地执行此操作(数组可能非常大)。有人知道我该怎么做吗?
【问题讨论】:
“不循环”是什么意思?你必须检查数组中的每一项,所以无论你使用什么技巧来隐藏循环,它都是 O(m*n)。 我认为他的意思是在 Numpy 中循环而不是在 Python 中循环。编译后的 Numpy 函数中的 O(mn) 比 Pythonfor
循环中的 O(mn) 快得多。当编译代码和解释代码的选项时,常量物质。 span>
From your comments
,因为您希望将其概括为处理通用编号。列,您可能会发现 this solution
这个问题值得一读。
【参考方案1】:
这是一种处理通用列数并且仍然是矢量化方法的方法 -
def rows_uniq_elems(a):
a_sorted = np.sort(a,axis=-1)
return a[(a_sorted[...,1:] != a_sorted[...,:-1]).all(-1)]
步骤:
按每一行排序。
查找每行中连续元素之间的差异。因此,任何具有至少一个零差异的行都表示重复元素。我们将使用它来获取有效行的掩码。因此,最后一步是使用掩码从输入数组中简单地选择有效行。
示例运行 -
In [49]: a
Out[49]:
array([[1, 2, 3, 7],
[4, 5, 6, 7],
[7, 8, 7, 8],
[0, 4, 5, 6],
[2, 2, 1, 1],
[0, 0, 0, 3],
[5, 4, 3, 2]])
In [50]: rows_uniq_elems(a)
Out[50]:
array([[1, 2, 3, 7],
[4, 5, 6, 7],
[0, 4, 5, 6],
[5, 4, 3, 2]])
【讨论】:
不感兴趣的 isnp.sort(a)
等同于 a[np.arange(idx.shape[0])[:,None], idx]
?
@EBB 不知道我为什么要采取那种间接的方式。更新了该排序。感谢您的建议!
太好了,谢谢!实际上,就像您上传时一样,我正在再次阅读您的答案!怪异的! ...
和:
在你的切片操作中是一样的吗?我以前没见过这个实现吗?我也很好奇使用axis = -1
和axis = 1
之间是否有区别?对于我的问题,两个操作都返回相同的答案?在您的解决方案中选择axis = -1
是否有特定原因?感谢您的帮助!
@EBB 这只是更通用一点,因为它处理任何通用维度数组以删除行。因此,任何 2D、3D 等数组现在都可以工作。【参考方案2】:
六年过去了,但这个问题帮助了我,所以我比较了 Divakar、Benjamin、Marcelo Cantos 和 Curtis Patrick 给出的答案的速度。
import numpy as np
vals = np.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])
def rows_uniq_elems1(a):
idx = a.argsort(1)
a_sorted = a[np.arange(idx.shape[0])[:,None], idx]
return a[(a_sorted[:,1:] != a_sorted[:,:-1]).all(-1)]
def rows_uniq_elems2(a):
a = (a[:,0] == a[:,1]) | (a[:,1] == a[:,2]) | (a[:,0] == a[:,2])
return np.delete(a, np.where(a), axis=0)
def rows_uniq_elems3(a):
return np.array([v for v in a if len(set(v)) == len(v)])
def rows_uniq_elems4(a):
return np.array([v for v in a if len(np.unique(v)) == len(v)])
结果:
%timeit rows_uniq_elems1(vals)
10000 loops, best of 3: 67.9 µs per loop
%timeit rows_uniq_elems2(vals)
10000 loops, best of 3: 156 µs per loop
%timeit rows_uniq_elems3(vals)
1000 loops, best of 3: 59.5 µs per loop
%timeit rows_uniq_elems(vals)
10000 loops, best of 3: 268 µs per loop
似乎使用set
胜过numpy.unique
。就我而言,我需要在更大的数组上执行此操作:
bigvals = np.random.randint(0,10,3000).reshape([3,1000])
%timeit rows_uniq_elems1(bigvals)
10000 loops, best of 3: 276 µs per loop
%timeit rows_uniq_elems2(bigvals)
10000 loops, best of 3: 192 µs per loop
%timeit rows_uniq_elems3(bigvals)
10000 loops, best of 3: 6.5 ms per loop
%timeit rows_uniq_elems4(bigvals)
10000 loops, best of 3: 35.7 ms per loop
没有列表推导的方法要快得多。但是,行数是硬编码的,很难扩展到三列以上,所以在我的情况下,至少使用集合的列表理解是最好的答案。
已编辑,因为我混淆了 bigvals
中的行和列
【讨论】:
【参考方案3】:这是一个选项:
import numpy
vals = numpy.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])
a = (vals[:,0] == vals[:,1]) | (vals[:,1] == vals[:,2]) | (vals[:,0] == vals[:,2])
vals = numpy.delete(vals, numpy.where(a), axis=0)
【讨论】:
我正在努力解决这个问题,干得好。但你不需要|不是^? 这比列表理解方法快得多,所以我可能会接受。想知道是否有任何方法可以推广到 NxM? @Ned Batchelder:是的,尽管在这种情况下它不会改变任何东西。 @jterrace 您可以通过生成 0-m 的组合进行概括,在生成器表达式中使用它们进行比较,然后减少|
以得到 a
。【参考方案4】:
与 Marcelo 相同,但我认为使用 numpy.unique()
而不是 set()
可能会完全符合您的目标。
numpy.array([v for v in vals if len(numpy.unique(v)) == len(v)])
【讨论】:
好吧,set
也有同样的意图,但也许numpy.unique
更快?
它实际上似乎要慢得多 - numpy.unique() 为 23 秒,而我的机器上 set() 为 3 秒,有 100 万行【参考方案5】:
numpy.array([v for v in vals if len(set(v)) == len(v)])
请注意,这仍然在幕后循环。你无法避免这一点。但即使是数百万行,它也应该可以正常工作。
【讨论】:
我想出了[item for item in vals if Counter(item).most_common(1)[0][1] is 1]
,但这更好,特别是因为你已经知道len(v)
。但是,您仍然在“循环”,因为您正在遍历数组。
虽然我需要重复项的索引位置,但对于大型数组来说,这实际上速度非常快,所以我喜欢 @Benjamin 的解决方案以上是关于删除 NumPy 数组中具有重复项的行的主要内容,如果未能解决你的问题,请参考以下文章