为啥 scipy 的稀疏 csr_matrix 的向量点积比 numpy 的密集数组慢?
Posted
技术标签:
【中文标题】为啥 scipy 的稀疏 csr_matrix 的向量点积比 numpy 的密集数组慢?【英文标题】:Why is vector dot product slower with scipy's sparse csr_matrix than numpy's dense array?为什么 scipy 的稀疏 csr_matrix 的向量点积比 numpy 的密集数组慢? 【发布时间】:2016-02-22 09:07:52 【问题描述】:我有一种情况,我需要从稀疏矩阵中提取单行,并用密集行获取其点积。使用 scipy 的 csr_matrix,这似乎比使用 numpy 的密集数组乘法要慢得多。这让我感到惊讶,因为我预计稀疏点积将涉及更少的操作。这是一个例子:
import timeit as ti
sparse_setup = 'import numpy as np; import scipy.sparse as si;' + \
'u = si.eye(10000).tocsr()[10];' + \
'v = np.random.randint(100, size=10000)'
dense_setup = 'import numpy as np; u = np.eye(10000)[10];' + \
'v = np.random.randint(100, size=10000)'
ti.timeit('u.dot(v)', setup=sparse_setup, number=100000)
2.788649031019304
ti.timeit('u.dot(v)', setup=dense_setup, number=100000)
2.179030169005273
对于矩阵向量乘法,稀疏表示很容易胜出,但在这种情况下并非如此。我尝试使用 csc_matrix,但性能更差:
>>> sparse_setup = 'import numpy as np; import scipy.sparse as si;' + \
... 'u = si.eye(10000).tocsc()[10];' + \
... 'v = np.random.randint(100, size=10000)'
>>> ti.timeit('u.dot(v)', setup=sparse_setup, number=100000)
7.0045155879925005
为什么在这种情况下 numpy 会击败 scipy.sparse?对于此类计算,是否有更快的矩阵格式?
【问题讨论】:
稀疏矩阵节省了内存,但在计算方面却更加复杂。 您如何计算“操作”?只是乘法和加法?或者您是否正在考虑索引、迭代等。在现代处理器上,将 2 个数字相乘并不是一项昂贵的操作。dot
也专门用于快速数值库。稀疏乘法也被编译,但不是使用相同的优化库。
为了清楚起见,您正在测试一个非常稀疏的向量(10000 个中的 1 个非零)乘以相同大小的密集向量。我认为它最终使用了sparse._sparsetools.csr_matvec
,一个编译函数。我必须查看 scipy github 才能进一步挖掘。
@hpaulj,我认为稀疏矩阵是行列值元组的集合。 k 元素稀疏行向量与 n 元素密集向量的点积应该只需要 O(k) 操作(常数因子高于密集),而密集乘法应该需要 O(n)。我会采纳你的好建议并阅读源代码。
当我改变u
(和密集等效项)中非零值的数量时,时间没有变化。 (u.data * v[u.indices]).sum()
更接近您想象的情况。它有点快,但它的时间仍然不是 O(k)。
【参考方案1】:
CSR/CSC 矢量乘积调用每次调用有几微秒的开销,从执行少量 Python 代码到处理编译代码 (scipy.sparse._sparsetools.csr_matvec) 中的参数。
在现代处理器上,计算矢量点积非常快,因此在这种情况下,开销支配了计算时间。矩阵向量产品本身更昂贵,在这里类似的开销是不可见的。
为什么 Numpy 的开销会更小?这主要是因为对代码进行了更好的优化; csr_matrix 的性能可能会在这里得到提升。
【讨论】:
以上是关于为啥 scipy 的稀疏 csr_matrix 的向量点积比 numpy 的密集数组慢?的主要内容,如果未能解决你的问题,请参考以下文章
python稀疏矩阵得到每列最大k项的值,对list内为类对象的排序(scipy.sparse.csr.csr_matrix)
以可移植数据格式保存/加载 scipy sparse csr_matrix
scipy csr_matrix和csc_matrix函数详解