Numba 中的稀疏矩阵
Posted
技术标签:
【中文标题】Numba 中的稀疏矩阵【英文标题】:Sparse Matrix in Numba 【发布时间】:2013-10-25 13:21:20 【问题描述】:我希望使用 Numba (http://numba.pydata.org/) 加速我的机器学习算法(用 Python 编写)。请注意,该算法将稀疏矩阵作为其输入数据。在我的纯 Python 实现中,我使用了来自 Scipy 的 csr_matrix 和相关类,但显然它与 Numba 的 JIT 编译器不兼容。
我还创建了自己的自定义类来实现稀疏矩阵(基本上是(索引,值)对列表的列表),但它再次与 Numba 不兼容(即,我收到一些奇怪的错误消息说它不识别扩展类型)
是否有另一种简单的方法来仅使用与 Numba 兼容的 numpy(不求助于 SciPy)来实现稀疏矩阵?任何示例代码将不胜感激。谢谢!
【问题讨论】:
您使用了csr_matrix
的哪些功能?您可以尝试在 numpy 中重现他们的行为,尽管我严重怀疑这通常会导致加速......
我只使用 csr_matrix 来存储我的数据。我需要的只是逐行迭代,然后对于我想要检索索引和值列表的每一行。这就是为什么现在我创建了自己的类,实现为一个简单的列表列表。但是 Numba 的编译器再次无法识别它。
【参考方案1】:
您可以将稀疏矩阵的数据作为纯 numpy 或 python 访问。例如
M=sparse.csr_matrix([[1,0,0],[1,0,1],[1,1,1]])
ML = M.tolil()
for d,r in enumerate(zip(ML.data,ML.rows))
# d,r are lists
dr = np.array([d,r])
print dr
产生:
[[1]
[0]]
[[1 1]
[0 2]]
[[1 1 1]
[0 1 2]]
当然,numba 可以处理使用这些数组的代码,当然前提是它不期望每一行都具有相同大小的数组。
lil
格式存储值 2 个对象 dtype 数组,数据和索引存储列表,按行。
【讨论】:
【参考方案2】:如果您只需遍历 CSR 矩阵的值,则可以将属性数据、indptr 和索引传递给函数而不是 CSR 矩阵对象。
from scipy import sparse
from numba import njit
@njit
def print_csr(A, iA, jA):
for row in range(len(iA)-1):
for i in range(iA[row], iA[row+1]):
print(row, jA[i], A[i])
A = sparse.csr_matrix([[1, 2, 0], [0, 0, 3], [4, 0, 5]])
print_csr(A.data, A.indptr, A.indices)
【讨论】:
以上是关于Numba 中的稀疏矩阵的主要内容,如果未能解决你的问题,请参考以下文章