在numpy数组中找到最大上三角条目索引的有效方法?

Posted

技术标签:

【中文标题】在numpy数组中找到最大上三角条目索引的有效方法?【英文标题】:Efficient way to find the index of the max upper triangular entry in a numpy array? 【发布时间】:2013-09-01 18:24:35 【问题描述】:

更具体地说,我有一个在选择最大条目时需要忽略的行/列列表。换句话说,在选择最大上三角条目时,需要跳过某些索引。在这种情况下,找到最大上三角入口位置的最有效方法是什么?

例如:

>>> a
array([[0, 1, 1, 1],
       [1, 2, 3, 4],
       [4, 5, 6, 6],
       [4, 5, 6, 7]])
>>> indices_to_skip = [0,1,2]

我需要在除a[0,1]a[0,2]a[1,2] 之外的所有元素中找到最小元素的索引。

【问题讨论】:

你能举个例子吗? 刚刚编辑了问题。谢谢! 那么,在这个例子中,你是在取最大的 [1,4,6] 吗?是否包括对角线? 是的,只是 [1,4,6] 并且不包括对角线。 你试过什么?既然您正在寻找最有效的方法,那么您应该已经尝试过一种效率不够高的方法。那么您尝试了什么,为什么没有成功? 【参考方案1】:

你可以使用np.triu_indices_from:

>>> np.vstack(np.triu_indices_from(a,k=1)).T
array([[0, 1],
       [0, 2],
       [0, 3],
       [1, 2],
       [1, 3],
       [2, 3]])

>>> inds=inds[inds[:,1]>2] #Or whatever columns you want to start from.
>>> inds
array([[0, 3],
       [1, 3],
       [2, 3]])


>>> a[inds[:,0],inds[:,1]]
array([1, 4, 6])

>>> max_index = np.argmax(a[inds[:,0],inds[:,1]])
>>> inds[max_index]
array([2, 3]])

或者:

>>> inds=np.triu_indices_from(a,k=1)
>>> mask = (inds[1]>2) #Again change 2 for whatever columns you want to start at.
>>> a[inds][mask]
array([1, 4, 6])

>>> max_index = np.argmax(a[inds][mask])
>>> inds[mask][max_index]
array([2, 3]])

对于上述情况,您可以使用inds[0] 跳过某些行。

跳过特定的行或列:

def ignore_upper(arr, k=0, skip_rows=None, skip_cols=None):
    rows, cols = np.triu_indices_from(arr, k=k)

    if skip_rows != None:
        row_mask = ~np.in1d(rows, skip_rows)
        rows = rows[row_mask]
        cols = cols[row_mask]

    if skip_cols != None:
        col_mask = ~np.in1d(cols, skip_cols)
        rows = rows[col_mask]
        cols = cols[col_mask]

    inds=np.ravel_multi_index((rows,cols),arr.shape)
    return np.take(arr,inds)

print ignore_upper(a, skip_rows=1, skip_cols=2) #Will also take numpy arrays for skipping.
[0 1 1 6 7]

两者可以结合起来,创造性地使用布尔索引可以帮助加快特定情况。

我遇到了一些有趣的事情,一种更快的获取上 triu 索引的方法:

def fast_triu_indices(dim,k=0):

    tmp_range = np.arange(dim-k)
    rows = np.repeat(tmp_range,(tmp_range+1)[::-1])

    cols = np.ones(rows.shape[0],dtype=np.int)
    inds = np.cumsum(tmp_range[1:][::-1]+1)

    np.put(cols,inds,np.arange(dim*-1+2+k,1))
    cols[0] = k
    np.cumsum(cols,out=cols)
    return (rows,cols)

虽然它不适用于k<0,但它的速度大约快了 6 倍:

dim=5000
a=np.random.rand(dim,dim)

k=50
t=time.time()
rows,cols=np.triu_indices(dim,k=k)
print time.time()-t
0.913508892059

t=time.time()
rows2,cols2,=fast_triu_indices(dim,k=k)
print time.time()-t
0.16515994072

print np.allclose(rows,rows2)
True

print np.allclose(cols,cols2)
True

【讨论】:

这看起来不错,但是如何找到最大元素的索引,而不仅仅是最大元素本身? @methane 我更新了前两个示例。你应该可以从那里拿走它。

以上是关于在numpy数组中找到最大上三角条目索引的有效方法?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 numpy 数组中找到行索引的最大值?

如何在numpy中找到每行中最大的索引,行的串联?

我需要一个 numpy 数组中的 N 个最小(索引)值

提高在具有许多条目设置为 np.inf 的二维 numpy 数组中查找最小元素的速度

numpy数组上的有效循环

在numpy数组上查找所有最大值