scikit-learn 中聚类的混淆矩阵
Posted
技术标签:
【中文标题】scikit-learn 中聚类的混淆矩阵【英文标题】:Confusion matrix for Clustering in scikit-learn 【发布时间】:2018-05-22 09:00:19 【问题描述】:我有一组带有已知标签的数据。我想尝试聚类,看看是否可以获得已知标签给出的相同聚类。为了测量准确性,我需要得到一个混淆矩阵之类的东西。
我知道对于分类问题的测试集,我可以轻松获得混淆矩阵。我已经尝试过this。
但是,它不能用于聚类,因为它期望列和行都具有相同的标签集,这对于分类问题很有意义。但是对于聚类问题,我期望的是这样的。
行 - 实际标签
列 - 新的集群名称(即 cluster-1、cluster-2 等)
有没有办法做到这一点?
编辑:这里有更多细节。
在sklearn.metrics.confusion_matrix 中,它期望y_test
和y_pred
具有相同的值,而labels
是这些值的标签。
这就是为什么它给出了一个矩阵,它对这样的行和列都有相同的标签。
但在我的情况下(KMeans 聚类),实际值是字符串,估计值是数字(即簇号)
因此,如果我打电话给confusion_matrix(y_true, y_pred)
,它会给出以下错误。
ValueError: Mix of label input types (string and number)
这是真正的问题。对于分类问题,这是有道理的。但是对于聚类问题,这个限制不应该存在,因为真实的标签名称和新的聚类名称不需要相同。
有了这个,我知道我正在尝试使用一个应该用于分类问题的工具来解决聚类问题。所以,我的问题是,有没有一种方法可以为可能的聚类数据获得这样的矩阵。
希望问题现在更清楚了。如果不是,请告诉我。
【问题讨论】:
请举例说明这一点 添加了更多细节。谢谢。 除非您知道如何将集群编号映射到实际结果,否则您将如何进行? 映射部分正是我想要学习的。我只想知道是否可以映射真实标签和自然簇数。如果我可以在列中获得真实标签并在行中获得集群名称(反之亦然),我可以自己完成。如果我以 Iris 数据集为例,基本上我想知道的是,我的每个新集群中有多少个 setosas、多少个 virginica 等。你明白我在找什么吗? 查看clustering performance evaluation in scikit-learn documentation 上的章节(例如,调整后的兰德指数、标准化/调整后的互信息、V-measure)。 【参考方案1】:您可以轻松计算成对交集矩阵。
但如果 sklearn 库已针对分类用例进行了优化,则可能需要自己执行此操作。
【讨论】:
谢谢,在自己编写之前,我只是想看看是否有 OOTB 方法可以做到这一点。 确实存在这样的实现。例如,在图表上,您通常具有相似性而不是距离。但是在某些时候,自己编写这些东西会变得更容易,而不是花太多时间将不同的库粘合在一起,然后一次被他们所有的错误所困扰。 这是我自己写的,并作为单独的答案发布。【参考方案2】:我自己写了一个代码。
# Compute confusion matrix
def confusion_matrix(act_labels, pred_labels):
uniqueLabels = list(set(act_labels))
clusters = list(set(pred_labels))
cm = [[0 for i in range(len(clusters))] for i in range(len(uniqueLabels))]
for i, act_label in enumerate(uniqueLabels):
for j, pred_label in enumerate(pred_labels):
if act_labels[j] == act_label:
cm[i][pred_label] = cm[i][pred_label] + 1
return cm
# Example
labels=['a','b','c',
'a','b','c',
'a','b','c',
'a','b','c']
pred=[ 1,1,2,
0,1,2,
1,1,1,
0,1,2]
cnf_matrix = confusion_matrix(labels, pred)
print('\n'.join([''.join([':4'.format(item) for item in row])
for row in cnf_matrix]))
编辑: (Dayyyuumm) 刚刚发现我可以使用Pandas Crosstab 轻松做到这一点:-/。
labels=['a','b','c',
'a','b','c',
'a','b','c',
'a','b','c']
pred=[ 1,1,2,
0,1,2,
1,1,1,
0,1,2]
# Create a DataFrame with labels and varieties as columns: df
df = pd.DataFrame('Labels': labels, 'Clusters': pred)
# Create crosstab: ct
ct = pd.crosstab(df['Labels'], df['Clusters'])
# Display ct
print(ct)
【讨论】:
使用 numpy 对代码进行矢量化,使其速度提高 10 倍。以上是关于scikit-learn 中聚类的混淆矩阵的主要内容,如果未能解决你的问题,请参考以下文章