如何在 sklearn 的交叉验证中获得多类 roc auc?

Posted

技术标签:

【中文标题】如何在 sklearn 的交叉验证中获得多类 roc auc?【英文标题】:How to get multi-class roc_auc in cross validate in sklearn? 【发布时间】:2020-07-04 19:56:39 【问题描述】:

我有一个分类问题,我想在 sklearn 中使用 cross_validate 获取 roc_auc 值。我的代码如下。

from sklearn import datasets
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features.
y = iris.target

from sklearn.ensemble import RandomForestClassifier
clf=RandomForestClassifier(random_state = 0, class_weight="balanced")

from sklearn.model_selection import cross_validate
cross_validate(clf, X, y, cv=10, scoring = ('accuracy', 'roc_auc'))

但是,我收到以下错误。

ValueError: multiclass format is not supported

请注意,我选择roc_auc具体是因为它支持binarymulticlass分类,如:https://scikit-learn.org/stable/modules/model_evaluation.html中所述

我也有二进制分类数据集。请告诉我如何解决此错误。

如果需要,我很乐意提供更多详细信息。

【问题讨论】:

【参考方案1】:

默认情况下multi_class='raise',因此您需要明确更改此设置。

来自docs:

multi_class 'raise', 'ovr', 'ovo', default='raise'

仅限多类。确定要使用的配置类型。这 默认值会引发错误,因此必须传递“ovr”或“ovo” 明确的。

'ovr':

计算每个类与其余类的 AUC [3] [4]。这对待 多类情况与多标签情况相同。敏感的 即使在average == 'macro' 时也会出现类不平衡,因为类 不平衡会影响每个“其余”分组的组成。

'ovo':

计算所有可能的成对组合的平均 AUC 类[5]。 average == 'macro'时对类不平衡不敏感。


解决方案:

使用make_scorer (docs):

from sklearn import datasets
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features.
y = iris.target

from sklearn.ensemble import RandomForestClassifier
clf=RandomForestClassifier(random_state = 0, class_weight="balanced")

from sklearn.metrics import make_scorer
from sklearn.metrics import roc_auc_score

myscore = make_scorer(roc_auc_score, multi_class='ovo',needs_proba=True)

from sklearn.model_selection import cross_validate
cross_validate(clf, X, y, cv=10, scoring = myscore)

【讨论】:

这给出了AxisError: axis 1 is out of bounds for array of dimension 1 in cross_validate。您需要在myscore 的定义中添加needs_proba=True。此外,最好先打乱数据。 @makis 非常感谢您的回答。但是,我收到以下错误TypeError: roc_auc_score() got an unexpected keyword argument 'multi_class'。有没有办法解决这个问题? :) @makis 还有一个问题。如果我想将它用于二进制分类,我应该做些什么改变?谢谢你:) 在 sklearn 0.22.2 中,函数 roc_auc_score 有这个参数。确保升级您的软件包。见:scikit-learn.org/stable/modules/generated/… @makis 非常感谢。如果我要使用您的代码进行二进制分类,如果我在没有multi_class 参数的情况下制作记分器是否正确?即myscore = make_scorer(roc_auc_score, needs_proba=True)。期待您的来信:)

以上是关于如何在 sklearn 的交叉验证中获得多类 roc auc?的主要内容,如果未能解决你的问题,请参考以下文章

使用 sklearn 嵌套交叉验证获得最佳参数

在 sklearn 中使用网格搜索和管道获得正确的交叉验证分数

有没有办法使用 SKlearn 获得滑动嵌套交叉验证?

如何对多类数据进行交叉验证?

SKlearn中具有嵌套交叉验证的分类报告(平均值/个体值)

如何在 sklearn 中编写自定义估算器并对其使用交叉验证?