sklearn 多类 roc auc 分数

Posted

技术标签:

【中文标题】sklearn 多类 roc auc 分数【英文标题】:sklearn multiclass roc auc score 【发布时间】:2020-11-27 21:39:46 【问题描述】:

如何获取sklearn中多类分类的roc auc分数?

二进制

# this works
roc_auc_score([0,1,1], [1,1,1])

多类

# this fails
from sklearn.metrics import roc_auc_score

ytest  = [0,1,2,3,2,2,1,0,1]
ypreds = [1,2,1,3,2,2,0,1,1]

roc_auc_score(ytest, ypreds,average='macro',multi_class='ovo')

# AxisError: axis 1 is out of bounds for array of dimension 1

我查看了官方documentation但无法解决问题。

【问题讨论】:

按照文档,您有一个multi_class 参数 【参考方案1】:

多标签案例中的 roc_auc_score 期望具有形状 (n_samples, n_classes) 的二进制标签指示符,这是回到一对多时尚的方式。

要轻松做到这一点,您可以使用 label_binarize (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.label_binarize.html#sklearn.preprocessing.label_binarize)。

对于您的代码,它将是:

from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

# You need the labels to binarize
labels = [0, 1, 2, 3]

ytest  = [0,1,2,3,2,2,1,0,1]

# Binarize ytest with shape (n_samples, n_classes)
ytest = label_binarize(ytest, classes=labels)

ypreds = [1,2,1,3,2,2,0,1,1]

# Binarize ypreds with shape (n_samples, n_classes)
ypreds = label_binarize(ypreds, classes=labels)


roc_auc_score(ytest, ypreds,average='macro',multi_class='ovo')

通常,这里 ypreds 和 yest 变成:

ytest
array([[1, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 0, 1, 0],
       [0, 0, 0, 1],
       [0, 0, 1, 0],
       [0, 0, 1, 0],
       [0, 1, 0, 0],
       [1, 0, 0, 0],
       [0, 1, 0, 0]])

ypreds
array([[0, 1, 0, 0],
       [0, 0, 1, 0],
       [0, 1, 0, 0],
       [0, 0, 0, 1],
       [0, 0, 1, 0],
       [0, 0, 1, 0],
       [1, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 1, 0, 0]])

【讨论】:

我使用的是 Python 3,我在上面运行了您的代码并收到以下错误:TypeError: roc_auc_score() got an unexpected keyword argument 'multi_class'。所以我更新到 scikit-learn 0.23.2(有 0.23.1)。感谢您的帖子。

以上是关于sklearn 多类 roc auc 分数的主要内容,如果未能解决你的问题,请参考以下文章

如何在 sklearn 的交叉验证中获得多类 roc auc?

ValueError:使用 sklearn roc_auc_score 函数不支持多类多输出格式

VotingClassifier 中的 roc_auc,scikit-learn (sklearn) 中的 RandomForestClassifier

不同的结果 roc_auc_score 和 plot_roc_curve

sklearn.metrics中的评估方法介绍(accuracy_score, recall_score, roc_curve, roc_auc_score, confusion_matrix,cla

sklearn学习:为什么roc_auc_score()和auc()有不同的结果?