如何比较使用 scikit-learn 库 load_svmlight_file 存储的 2 个稀疏矩阵?
Posted
技术标签:
【中文标题】如何比较使用 scikit-learn 库 load_svmlight_file 存储的 2 个稀疏矩阵?【英文标题】:How to compare 2 sparse matrix stored using scikit-learn library load_svmlight_file? 【发布时间】:2014-06-01 05:01:03 【问题描述】:我正在尝试比较测试和训练数据集中存在的特征向量。这些特征向量使用 scikitlearn 库 load_svmlight_file 以稀疏格式存储。两个数据集的特征向量的维度相同。但是,我收到此错误:“具有多个元素的数组的真值不明确。使用 a.any() 或 a.all()。”
为什么会出现此错误? 我该如何解决?
提前致谢!
from sklearn.datasets import load_svmlight_file
pathToTrainData="../train.txt"
pathToTestData="../test.txt"
X_train,Y_train= load_svmlight_file(pathToTrainData);
X_test,Y_test= load_svmlight_file(pathToTestData);
for ele1 in X_train:
for ele2 in X_test:
if(ele1==ele2):
print "same vector"
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-3-c1f145f984a6> in <module>()
7 for ele1 in X_train:
8 for ele2 in X_test:
----> 9 if(ele1==ele2):
10 print "same vector"
/Users/rkasat/anaconda/lib/python2.7/site-packages/scipy/sparse/base.pyc in __bool__(self)
181 return True if self.nnz == 1 else False
182 else:
--> 183 raise ValueError("The truth value of an array with more than one "
184 "element is ambiguous. Use a.any() or a.all().")
185 __nonzero__ = __bool__
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().
【问题讨论】:
【参考方案1】:您可以使用此条件来检查两个稀疏数组是否完全相等,而无需对其进行致密化:
if (ele1 - ele2).nnz == 0:
# Matched, do something ...
nnz
属性给出稀疏数组中非零元素的数量。
一些简单的测试运行来显示差异:
import numpy as np
from scipy import sparse
A = sparse.rand(10, 1000000).tocsr()
def benchmark1(A):
for s1 in A:
for s2 in A:
if (s1 - s2).nnz == 0:
pass
def benchmark2(A):
for s1 in A:
for s2 in A:
if (s1.toarray() == s2).all() == 0:
pass
%timeit benchmark1(A)
%timeit benchmark2(A)
一些结果:
# Computer 1
10 loops, best of 3: 36.9 ms per loop # with nnz
1 loops, best of 3: 734 ms per loop # with toarray
# Computer 2
10 loops, best of 3: 28 ms per loop
1 loops, best of 3: 312 ms per loop
【讨论】:
不错,绝对更优雅!尽管复制您的代码会使我的笔记本电脑上的基准测试不那么严格:In [6]: %timeit benchmark1(A) 1 loops, best of 3: 7.76 s per loop In [7]: %timeit benchmark2(A) 1 loops, best of 3: 15.4 s per loop
:)(再说一次,随机种子没有修复 :))
虽然当时我的电脑一定出了什么问题。我必须重新检查一下
奇怪的是,即使在重新启动后,基准测试在我的笔记本电脑上仍然如此糟糕。这很有趣,我会调查它。同时我在回答中写了一个新的比较函数来参加比赛:)
我无法找到我的问题。循环在我的笔记本电脑上分别以 7 秒和 15 秒运行(相当新的戴尔,8G 内存,4 核,anaconda,所以没有硬件问题先验)。有人知道为什么会这样吗? (我写的比较函数每次循环运行 13ms,所以发生了一些奇怪的事情)
好的,大电脑说In [6]: %timeit benchmark1(A) 10 loops, best of 3: 49.8 ms per loop In [7]: %timeit benchmark2(A) 1 loops, best of 3: 1.06 s per loop
这样效果更好。不过,我的笔记本电脑仍然一无所知。对于它的价值,我建议的比较功能给In [10]: %timeit benchmark3(A) 100 loops, best of 3: 10.2 ms per loop
【参考方案2】:
如果您的数组很密集,您可能会遇到同样的问题,并且解决方案很简单。替换
if(ele1==ele2):
与
if (ele1 == ele2).all():
但是,由于您使用的是稀疏矩阵,所以这个问题实际上并不是那么容易。值得注意的是,函数all
和any
没有实现稀疏矩阵(至少对于all
是可以理解的,因为all
只能返回True
,如果测试的矩阵被密集填充评估为True
)。
在您的情况下,由于您只是比较稀疏矩阵的行,您可能会发现对它们进行致密化然后进行比较是可以接受的。尝试用
替换提到的行if (ele1.toarray() == ele2).all(): # Densifying one of them casts the other to dense too
一般来说,您似乎想要比较 2 个矩阵的行。根据条目的数量,这可以通过定义矢量化比较函数来更有效地完成,如下所示:
def compare(A, B):
return zip(*np.where((np.array(A.multiply(A).sum(1)) +
np.array(B.multiply(B).sum(1)).T) - 2 * A.dot(B.T).toarray() == 0))
此函数将返回一对索引列表,告诉您哪些行相互对应,并且比代码中使用的双 for 循环效率高得多。
解释:函数compare
使用二项式公式(a - b) ** 2 == a ** 2 + b ** 2 - 2 * a * b
计算两两欧式距离。这个公式也适用于 l2 范数和标量产品。如果矩阵不是稀疏的,公式会变得更简单:squared_distances = (A ** 2).sum(axis=1) + (B ** 2).sum(axis=1) - 2 * A.dot(B.T)
。然后我们使用np.where
检查这些条目中的哪些等于0,并将它们作为元组返回。
以此为基准,我们得到:
import numpy as np
from scipy import sparse
rng = np.random.RandomState(42)
A = sparse.rand(10, 1000000, random_state=rng).tocsr()
In [12]: %timeit compare(A, A)
100 loops, best of 3: 10.2 ms per loop
【讨论】:
以上是关于如何比较使用 scikit-learn 库 load_svmlight_file 存储的 2 个稀疏矩阵?的主要内容,如果未能解决你的问题,请参考以下文章