如何在 Scikit 中计算多类分类的混淆矩阵?

Posted

技术标签:

【中文标题】如何在 Scikit 中计算多类分类的混淆矩阵?【英文标题】:How compute confusion matrix for multiclass classification in Scikit? 【发布时间】:2017-09-25 16:15:17 【问题描述】:

我有一个多类分类任务。当我基于scikit example 运行我的脚本时,如下所示:

classifier = OneVsRestClassifier(GradientBoostingClassifier(n_estimators=70, max_depth=3, learning_rate=.02))

y_pred = classifier.fit(X_train, y_train).predict(X_test)
cnf_matrix = confusion_matrix(y_test, y_pred)

我收到此错误:

File "C:\ProgramData\Anaconda2\lib\site-packages\sklearn\metrics\classification.py", line 242, in confusion_matrix
    raise ValueError("%s is not supported" % y_type)
ValueError: multilabel-indicator is not supported

我尝试将labels=classifier.classes_ 传递给confusion_matrix(),但没有帮助。

y_test 和 y_pred 如下:

y_test =
array([[0, 0, 0, 1, 0, 0],
   [0, 0, 0, 0, 1, 0],
   [0, 1, 0, 0, 0, 0],
   ..., 
   [0, 0, 0, 0, 0, 1],
   [0, 0, 0, 1, 0, 0],
   [0, 0, 0, 0, 1, 0]])


y_pred = 
array([[0, 0, 0, 0, 0, 0],
   [0, 0, 0, 0, 0, 0],
   [0, 0, 0, 0, 0, 0],
   ..., 
   [0, 0, 0, 0, 0, 1],
   [0, 0, 0, 0, 0, 1],
   [0, 0, 0, 0, 0, 0]])

【问题讨论】:

为什么你有 y_predy_test 作为 one-hot 编码数组?你原来的类标签是什么?你应该给出你的代码,从你如何转换你的y开始。 @VivekKumar 我将y_trainy_test 二进制化为y_test = label_binarize(y_test, classes=[0, 1, 2, 3, 4, 5])OneVsRestClassifier() 您应该将原始类(未二进制化)放入confusion_matrix。您需要反向转换您的 y_pred 以从中获取原始类。 @VivekKumar 谢谢。我使用了非二值化版本,它解决了。 【参考方案1】:

这对我有用:

y_test_non_category = [ np.argmax(t) for t in y_test ]
y_predict_non_category = [ np.argmax(t) for t in y_predict ]

from sklearn.metrics import confusion_matrix
conf_mat = confusion_matrix(y_test_non_category, y_predict_non_category)

其中y_testy_predict 是分类变量,例如单热向量。

【讨论】:

【参考方案2】:

首先您需要创建标签输出数组。 假设您有 3 个类: 'cat', 'dog', 'house' 索引: 0,1,2 。 2个样本的预测是:'dog','house'。 您的输出将是:

y_pred = [[0, 1, 0],[0, 0, 1]]

运行 y_pred.argmax(1) 得到:[1,2] 这个数组代表原始标签索引,意思是: ['狗','房子']

num_classes = 3

# from lable to categorial
y_prediction = np.array([1,2]) 
y_categorial = np_utils.to_categorical(y_prediction, num_classes)

# from categorial to lable indexing
y_pred = y_categorial.argmax(1)

【讨论】:

【参考方案3】:

我只是从预测y_pred 矩阵中减去输出y_test 矩阵,同时保持分类格式。对于-1,我假设为假阴性,而对于1,我假设为假阳性。

下一个:

if output_matrix[i,j] == 1 and predictions_matrix[i,j] == 1:  
    produced_matrix[i,j] = 2 

以下列符号结束:

-1:假阴性  1:误报  0:真否定  2:真阳性

最后,进行一些简单的计数可以产生任何混淆指标。

【讨论】:

以上是关于如何在 Scikit 中计算多类分类的混淆矩阵?的主要内容,如果未能解决你的问题,请参考以下文章

为多类多标签分类构建混淆矩阵

如何从多类分类的混淆矩阵中提取假阳性、假阴性

scikit-learn 多分类混淆矩阵

使用 R 在 keras 中为多类分类创建混淆矩阵

Scikit 学习如何打印混淆矩阵的标签?

如何标准化混淆矩阵?