"ValueError: multilabel-indicator format is not supported" for roc_curve() sklearn

Posted: 2021-03-22 00:01:42

Question:

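I am trying to get the TPR (true positive rate) and FPR (false positive rate) from roc_curve(), and then the AUC score, so that I can plot a graph and see how my model performs on imbalanced multilabel data (500 labels), but I get an error.

I compute the predicted probability for each label so that I can adjust the threshold to get better precision, recall and accuracy, and to recover most of the target labels at prediction time.

Code: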

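from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import ClassifierChain

# Balanced class weights to compensate for the label imbalance
rfc = RandomForestClassifier(n_jobs=-1, random_state=0, class_weight='balanced')
# Chain one random forest per label; each link also sees the earlier labels
clf2 = ClassifierChain(rfc)
clf2.fit(X_train, y_train)
# One probability per sample and per label
y_pred = clf2.predict_proba(X_test)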

y_pred.shape
>> (8125,500)

y_pred[0]
>> array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.  , 0.01, 0.  , 0.  , 0.  ,
        0.  , 0.01, 0.  , 0.  , 0.  , 0.01, 0.  , 0.01, 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  ,
        0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.03, 0.  , 0.  , 0.  , 0.  , 0.  , 0.01,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.5 , 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.  ,
        0.  , 0.05, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.01, 0.  , 0.02, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.03, 0.04, 0.  ,
        0.  , 0.  , 0.01, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.02, 0.  ,
        0.  , 0.01, 0.  , 0.01, 0.  , 0.28, 0.  , 0.  , 0.  , 0.  , 0.01,
        0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  ,
        0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.02, 0.07, 0.  , 0.01, 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.02, 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.02, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.02, 0.01, 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.  ,
        0.  , 0.  , 0.  , 0.03, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.15, 0.  , 0.  , 0.02, 0.  ,
        0.01, 0.  , 0.11, 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.02, 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.02, 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.1 , 0.02, 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.02,
        0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  ]])

from sklearn.metrics import roc_auc_score,roc_curve,precision_recall_curve
fpr, tpr, thresholds = roc_curve(y_test,y_pred)

The last line of code gives me the error.

Traceback:

ValueError                                Traceback (most recent call last)

<ipython-input-72-ea45ece64953> in <module>()
      1 from sklearn.metrics import roc_auc_score,roc_curve,precision_recall_curve
----> 2 fpr, tpr, thresholds = roc_curve(y_test,y_pred)

1 frames

/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
    534     if not (y_type == "binary" or
    535             (y_type == "multiclass" and pos_label is not None)):
--> 536         raise ValueError("{0} format is not supported".format(y_type))
    537 
    538     check_consistent_length(y_true, y_score, sample_weight)

ValueError: multilabel-indicator format is not supported

Answer 1:

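The point here is that, as stated in the docs for sklearn.metrics.roc_curve(),

Note: this implementation is restricted to the binary classification task.

whereas your target data (y_train and y_test) is multilabel (sklearn.utils.multiclass.type_of_target(y_train) returns 'multilabel-indicator').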
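To make the distinction concrete, here is a minimal sketch of how type_of_target classifies a 2-D indicator matrix versus one of its columns. The array y_multilabel is a made-up toy target, not the asker's data.

```python
import numpy as np
from sklearn.utils.multiclass import type_of_target

# Hypothetical toy targets: 3 samples, 3 labels (the real y_test has 500 columns).
y_multilabel = np.array([[1, 0, 1],
                         [0, 1, 0],
                         [1, 1, 0]])

# The whole 2-D indicator matrix is what roc_curve() rejects.
kind_full = type_of_target(y_multilabel)

# A single label column is an ordinary binary target, which roc_curve() accepts.
kind_column = type_of_target(y_multilabel[:, 0])

print(kind_full)
print(kind_column)
```

This suggests the error is about the shape of the target passed to roc_curve(), not about the classifier itself.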

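That said, there are different ways to evaluate a multilabel (or multioutput) classifier. One approach is to measure the metric for each individual label and then average over all labels (so-called macro-averaging, though it is not the only option; see here for further references).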
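As a rough illustration of macro-averaging, the sketch below computes the ROC AUC per label and averages the results, then compares that with roc_auc_score(..., average="macro"), which accepts multilabel indicator targets directly. The data is synthetic and the variable names are placeholders; with real data, y_true and y_score would be y_test and the predicted probabilities.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Synthetic stand-in data (assumption): 200 samples, 5 labels.
rng = np.random.default_rng(0)
n_samples, n_labels = 200, 5
y_true = rng.integers(0, 2, size=(n_samples, n_labels))
# Scores that are noisy but correlated with the true labels.
y_score = 0.3 * y_true + rng.random((n_samples, n_labels))

# Macro average by hand: one binary AUC per label column, then the plain mean.
per_label_auc = [roc_auc_score(y_true[:, j], y_score[:, j])
                 for j in range(n_labels)]
macro_manual = float(np.mean(per_label_auc))

# The same quantity computed by scikit-learn on the full indicator matrix.
macro_sklearn = roc_auc_score(y_true, y_score, average="macro")

# Micro average: pool every (sample, label) pair into one binary problem.
micro = roc_auc_score(y_true, y_score, average="micro")

print(macro_manual, macro_sklearn, micro)
```

Macro-averaging weights every label equally, so rare labels count as much as frequent ones; micro-averaging pools all decisions and is dominated by the frequent labels. Note that a label column containing only one class in the test set has no defined AUC and would need to be excluded first.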

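In the case of ROC curves, this means plotting one ROC curve per label/class, either by first training n_classes binary classifiers (the OvA strategy) or by exploiting an inherently multilabel classifier. Then, as shown here, you can also compute and plot the macro-average ROC curve. So, depending on the kind of averaging you adopt, there are different ways to extend this binary metric to the multilabel setting.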
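One possible way to put this into code is sketched below: roc_curve() is called once per label column, and the macro-average curve is obtained by interpolating every per-label curve onto a common FPR grid and averaging the TPR values. The helper names per_label_roc and macro_average_roc are hypothetical, the data is synthetic, and labels with only one class present are skipped as an assumption about how to handle them.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve


def per_label_roc(y_true, y_score):
    """Return {label_index: (fpr, tpr, auc)} with one binary ROC curve per label."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    curves = {}
    for j in range(y_true.shape[1]):
        # A ROC curve is undefined when a label has only one class present.
        if np.unique(y_true[:, j]).size < 2:
            continue
        fpr, tpr, _ = roc_curve(y_true[:, j], y_score[:, j])
        curves[j] = (fpr, tpr, auc(fpr, tpr))
    return curves


def macro_average_roc(curves, n_points=101):
    """Average the per-label TPRs on a shared FPR grid."""
    grid = np.linspace(0.0, 1.0, n_points)
    interpolated = [np.interp(grid, fpr, tpr) for fpr, tpr, _ in curves.values()]
    mean_tpr = np.mean(interpolated, axis=0)
    return grid, mean_tpr, auc(grid, mean_tpr)


# Synthetic stand-in data (assumption); replace with y_test and y_pred.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(200, 5))
y_score = 0.3 * y_true + rng.random((200, 5))

curves = per_label_roc(y_true, y_score)
grid, mean_tpr, macro_auc = macro_average_roc(curves)
print(len(curves), macro_auc)
```

Each (fpr, tpr) pair, as well as (grid, mean_tpr), can be passed to a plotting call such as matplotlib's plt.plot. With 500 labels it is probably more readable to plot the macro-average curve plus a handful of selected labels rather than every curve.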

Comments:

Thanks for the clear explanation... Could you suggest any metric I could use to measure the model's efficiency?
