如何找到重新排序的 numpy 数组的索引?
Posted
技术标签:
【中文标题】如何找到重新排序的 numpy 数组的索引?【英文标题】:How to find indices of a reordered numpy array? 【发布时间】:2017-07-03 02:52:42 【问题描述】:假设我有一个排序的 numpy 数组:
arr = np.array([0.0, 0.0],
[0.5, 0.0],
[1.0, 0.0],
[0.0, 0.5],
[0.5, 0.5],
[1.0, 0.5],
[0.0, 1.0],
[0.5, 1.0],
[1.0, 1.0])
并假设我对其进行了非平凡的操作,这样我就有了一个与旧数组相同但顺序不同的新数组:
arr2 = np.array([0.5, 0.0],
[0.0, 0.0],
[0.0, 0.5],
[1.0, 0.0],
[0.5, 0.5],
[1.0, 0.5],
[0.0, 1.0],
[1.0, 1.0],
[0.5, 1.0])
问题是:如何获得arr2
的每个元素在arr
中的位置的索引。换句话说,我想要一个方法,它接受两个数组并返回一个与arr2
相同长度但元素索引为arr
的数组。例如,返回数组的第一个元素是arr2
在arr
中的第一个元素的索引。
where_things_are(arr2, arr)
return : array([1, 0, 3, 2, 4, 5, 6, 8, 7])
这样的函数是否已经存在于 numpy 中?
编辑:
我试过了:
np.array([np.where((arr == x).all(axis=1)) for x in arr2])
返回我想要的,但我的问题仍然存在:有没有更有效的方法来使用 numpy 方法?
EDIT2:
如果arr2
的长度与原始数组的长度不同(比如我从中删除了一些元素),它也应该起作用。因此,它不是查找和反转排列,而是查找元素所在的位置。
【问题讨论】:
“逆”不会是唯一的——通过添加索引轴来增加原始 arr 会更好,通过“非平凡操作”进行处理 我使用的非平凡操作将保留唯一性是的,但保留原始索引无济于事,因为该操作不会保留顺序。 对添加的索引轴也进行相同的重新排序操作,之后索引仍然标记arr转换后元素的原始位置,便于在添加的索引轴上排序以恢复原始顺序 “如果 arr2 的长度与原始数组的长度不同,它也应该可以工作”-停止改变我们的问题。 【参考方案1】:关键是反转排列。即使原始数组未排序,下面的代码也可以工作。如果已排序,则可以使用find_map_sorted
,这显然更快。
更新:为了适应 OP 不断变化的需求,我添加了一个处理丢失元素的分支。
import numpy as np
def invperm(p):
q = np.empty_like(p)
q[p] = np.arange(len(p))
return q
def find_map(arr1, arr2):
o1 = np.argsort(arr1)
o2 = np.argsort(arr2)
return o2[invperm(o1)]
def find_map_2d(arr1, arr2):
o1 = np.lexsort(arr1.T)
o2 = np.lexsort(arr2.T)
return o2[invperm(o1)]
def find_map_sorted(arr1, arrs=None):
if arrs is None:
o1 = np.lexsort(arr1.T)
return invperm(o1)
# make unique-able
rdtype = np.rec.fromrecords(arrs[:1, ::-1]).dtype
recstack = np.r_[arrs[:,::-1], arr1[:,::-1]].view(rdtype).view(np.recarray)
uniq, inverse = np.unique(recstack, return_inverse=True)
return inverse[len(arrs):]
x1 = np.random.permutation(100000)
x2 = np.random.permutation(100000)
print(np.all(x2[find_map(x1, x2)] == x1))
rows = np.random.random((100000, 8))
r1 = rows[x1, :]
r2 = rows[x2, :]
print(np.all(r2[find_map_2d(r1, r2)] == r1))
rs = r1[np.lexsort(r1.T), :]
print(np.all(rs[find_map_sorted(r2), :] == r2))
# lose ten elements
print(np.all(rs[find_map_sorted(r2[:-10], rs), :] == r2[:-10]))
【讨论】:
【参考方案2】:如果你保证唯一性:
[ np.where(np.logical_and((arr2==x)[:,1], (arr2==x)[:,0])==True)[0][0] for x in arr]
请注意,我将您的数组转换为 2D: 例如
arr2 = np.array([[0.5, 0.0],
[0.0, 0.0],
[0.0, 0.5],
[1.0, 0.0],
[0.5, 0.5],
[1.0, 0.5],
[0.0, 1.0],
[1.0, 1.0],
[0.5, 1.0]])
【讨论】:
【参考方案3】:这是一种使用 numpy Broadcasting 的方法:
In [10]: ind = np.where(arr[:, None] == arr2[None, :])[1]
In [11]: ind[np.where(np.diff(ind)==0)]
Out[11]: array([1, 0, 3, 2, 4, 5, 6, 8, 7])
这背后的想法是,增加数组的维度,以便它们的比较产生一个 3d 数组,因为原始子数组的长度为 2,如果我们在比较结果的第二个轴上有两个连续相等的项目,它们将是其中两个项目相等。为了更好地演示,这里是没有选择第二个轴的比较结果:
In [96]: np.where(arr[:, None] == arr2[None, :])
Out[96]:
(array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3,
3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
7, 7, 8, 8, 8, 8, 8, 8]),
array([0, 1, 1, 2, 3, 6, 0, 0, 1, 3, 4, 8, 0, 1, 3, 3, 5, 7, 1, 2, 2, 4, 5,
6, 0, 2, 4, 4, 5, 8, 2, 3, 4, 5, 5, 7, 1, 2, 6, 6, 7, 8, 0, 4, 6, 7,
8, 8, 3, 5, 6, 7, 7, 8]),
array([1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1,
0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1,
0, 1, 0, 0, 1, 0, 1, 1]))
然后为了找到这些项目,我们只需要找到它们的差异为 0 的地方。
【讨论】:
【参考方案4】:numpy_indexed 包(免责声明:我是它的作者)包含针对此类问题的有效功能; npi.indices 是 list.index 的 ndarray 等价物。
import numpy_indexed as npi
idx = npi.indices(arr, arr2)
这将返回一个索引列表,例如 arr[idx] == arr2。如果 arr2 包含 arr 中不存在的元素,则会引发 ValueError;但是您可以使用“缺失”的 kwarg 来控制它。
如果此功能包含在 numpy 中,请回答您的问题;是的,从某种意义上说,numpy 是一个图灵完备的生态系统。但并非如此,如果您计算以高效、正确和通用的方式实现此功能所需的代码行数。
【讨论】:
看起来是一个有趣的扩展。您介意 - 非常简短地 - 描述您正在使用的算法吗?谢谢! 它类似于此处描述的其他基于 arg-sorting 的方法,并且在性能上也应该相似。额外的代码行主要是为了覆盖边缘情况并使其更通用(比如处理 ndarray、在任意轴上获取索引、有趣的 dtype 等等)以上是关于如何找到重新排序的 numpy 数组的索引?的主要内容,如果未能解决你的问题,请参考以下文章
python 按二维数组的某行或列排序 (numpy lexsort)
python使用np.argsort对一维numpy概率值数据排序获取倒序索引获取的top索引(例如top2top5top10)索引二维numpy数组中对应的原始数据:原始数据概率最大的头部数据
python使用np.argsort对一维numpy概率值数据排序获取升序索引获取的top索引(例如top2top5top10)索引二维numpy数组中对应的原始数据:原始数据概率最小的头部数据
Python计算两个numpy数组的交集(Intersection)实战:两个输入数组的交集并排序获取交集元素及其索引如果输入数组不是一维的,它们将被展平(flatten),然后计算交集