Python - cdist 函数中数组的维度问题

Posted

技术标签:

【中文标题】Python - cdist 函数中数组的维度问题【英文标题】:Python - Issue with the dimension of array in cdist function 【发布时间】:2019-05-20 21:30:37 【问题描述】:

我正在尝试为 k-means 找到正确的簇数,并为此使用 cdist 函数。

我可以理解 cdist 的论点应该是相同的维度。我尝试打印两个参数的大小,即 (2542, 39) 和 (1, 39)。

有人可以建议我哪里出错了吗?

print(tfidf_matrix.shape) ### Returning --> (2542, 39)
#Finding optimal no. of clusters
from scipy.spatial.distance import cdist
clusters=range(1,10)
meanDistortions=[]

for k in clusters:
    model=KMeans(n_clusters=k)
    model.fit(tfidf_matrix)
    prediction=model.predict(tfidf_matrix)
    print(model.cluster_centers_.shape)  ## Returning (1, 39)
    meanDistortions.append(sum(np.min(cdist(tfidf_matrix, model.cluster_centers_, 'euclidean'), axis=1)) /
                           tfidf_matrix.shape[0])

错误:

ValueError                                Traceback (most recent call last)
<ipython-input-181-c15e32d863d2> in <module>()
     12     prediction=model.predict(tfidf_matrix)
     13     print(model.cluster_centers_.shape)
---> 14     meanDistortions.append(sum(np.min(cdist(tfidf_matrix, model.cluster_centers_, 'euclidean'), axis=1)) /
     15                            tfidf_matrix.shape[0])
     16 

~\Downloads\Conda\envs\data-science\lib\site-packages\scipy\spatial\distance.py in cdist(XA, XB, metric, *args, **kwargs)
   2588 
   2589     if len(s) != 2:
-> 2590         raise ValueError('XA must be a 2-dimensional array.')
   2591     if len(sB) != 2:
   2592         raise ValueError('XB must be a 2-dimensional array.')

ValueError: XA must be a 2-dimensional array.

【问题讨论】:

使用tfidf_matrix = np.random.randn(2542, 39) 并导入sklearn.cluster.KMeans,您的代码对我来说运行良好。 你在meanDistortions.append()中缺少, @ScottWarchal 哦,是的,带有 (2542, 39) 的随机 np 数组对我来说也可以正常工作。如果尺寸不是问题,您能建议可能会出现什么问题吗? 【参考方案1】:

这可能是一个类型问题。

Tfidf 可能不是 cdist 要求的 dense 矩阵。当然在这里使用稀疏矩阵是有意义的。

但是,cdist 似乎不接受稀疏矩阵:scipy cdist with sparse matrices

【讨论】:

以上是关于Python - cdist 函数中数组的维度问题的主要内容,如果未能解决你的问题,请参考以下文章

Numpy数组维度

使用python中的内置函数查找3d距离

python如何减小维度

练习 - 使用维度数组和函数查找密码

numpy数组的堆叠:numpy.stack, numpy.hstack, numpy.vstack

Python之数组拼接,组合,连接