计算 numpy 数组和 csr_matrix 之间的成对最小值的最有效方法

Posted

技术标签:

【中文标题】计算 numpy 数组和 csr_matrix 之间的成对最小值的最有效方法【英文标题】:Most effective way to compute the pairwise minimum between a numpy array and a csr_matrix 【发布时间】:2021-12-06 20:24:12 【问题描述】:

我有一个形状为(1, 1000) 的numpy 数组V。我还有一个 csr_matrix M,形状为 (100000, 1000)。对于M 中的每一行m,我想计算Vm 之间的成对最小值,并将所有结果存储在一个新矩阵中,并且我想高效地完成它。最终结果也应该是一个形状为(100000, 1000)的矩阵。

我考虑/尝试过的一些方法:

使用 for 循环遍历 M 的每一行。这可行,但速度很慢。 将M 转换为矩阵:numpy.minimum(V, M.toarray()),这会占用大量内存。 numpy.minimum(V, M) 不起作用。我收到一条错误消息:Comparing a sparse matrix with a scalar less than zero using >= is inefficient

在不占用太多内存或时间的情况下,有什么好的和有效的方法来做到这一点?

【问题讨论】:

v 的某些值是否为负数? 我第二个@hpaulj 的问题。如果v 中的值是非负数,那么有一些相当有效的方法。 不,VM 中的所有数字都在 0 和 1 之间,包括 0 和 1。 这是找到元素最大值的有效解决方案。 ***.com/a/64920528/4045774 【参考方案1】:

如果v 中的值是非负数,这是一种简洁的方法,应该比循环遍历每一行要快得多:

import numpy as np
from scipy.sparse import csr_matrix

def rowmin(M, v):
    # M must be a csr_matrix, and v must be a 1-d numpy array with
    # length M.shape[1].  The values in v must be nonnegative.
    if np.any(v < 0):
        raise ValueError('v must not contain negative values.')

    # B is a CSR matrix with the same sparsity pattern as M, but its
    # data values are from v:
    B = csr_matrix((v[M.indices], M.indices, M.indptr))
    return M.minimum(B)

为了在v 中允许负值,此修改有效,但当v 具有负值时会生成警告,因为在将负值复制到其中时必须更改B 中的稀疏模式。 (多行代码可以消除警告。)v 中的许多负值可能会显着降低性能。

def rowmin(M, v):
    # M must be a csr_matrix, and v must be a 1-d numpy array with
    # length M.shape[1].

    # B is a CSR matrix with the same sparsity pattern as M, but its
    # data values are from v:
    B = csr_matrix((v[M.indices], M.indices, M.indptr))

    # If there are negative values in v, include them in B.
    negmask = v < 0
    if np.any(negmask):
        negindices = negmask.nonzero()[0]
        B[:, negindices] = v[negindices]

    return M.minimum(B)

【讨论】:

漂亮的解决方案。它就像一个魅力,它是如此的快。太感谢了。 :)

以上是关于计算 numpy 数组和 csr_matrix 之间的成对最小值的最有效方法的主要内容,如果未能解决你的问题,请参考以下文章

python使用scipy中的sparse.csr_matrix函数将numpy数组转化为稀疏矩阵(Create A Sparse Matrix)

将numpy对象数组转换为稀疏矩阵

数据分析之Numpy-数组计算

数据分析之numpy篇

数据分析之numpy篇

NumPy之:NumPy简介教程