SGD 分类器 Precision-Recall 曲线

Posted

技术标签:

【中文标题】SGD 分类器 Precision-Recall 曲线【英文标题】:SGD classifier Precision-Recall curve 【发布时间】:2021-07-31 06:14:46 【问题描述】:

我正在研究一个二元分类问题,我有一个像这样的 sgd 分类器:

sgd = SGDClassifier(
    max_iter            = 1000, 
    tol                 = 1e-3,
    validation_fraction = 0.2,
    class_weight = 0:0.5, 1:8.99
)

我将它安装在我的训练集上并绘制了精确召回曲线:

from sklearn.metrics import plot_precision_recall_curve
disp = plot_precision_recall_curve(sgd, X_test, y_test)

鉴于 scikit-learn 中的 sgd 分类器默认使用loss="hinge",那么这条曲线怎么可能绘制出来呢?我的理解是 sgd 的输出不是概率的——要么是 1/0。因此没有“阈值”,但是 sklearn 精确召回曲线绘制了具有不同阈值的锯齿形图。这是怎么回事?

【问题讨论】:

【参考方案1】:

您描述的情况实际上与 documentation example 中的情况相同,使用前 2 类虹膜数据和 LinearSVC 分类器(该算法使用平方铰链损失,就像您在此处使用的铰链损失一样,导致分类器只产生二元结果而不是概率结果)。产生的情节是:

即这里的质量和你的差不多。

尽管如此,您的问题是一个合理的问题,而且确实是一个不错的问题;当我们的分类器确实不产生概率预测时(因此任何阈值的概念听起来都无关紧要),我们怎么会得到与概率分类器产生的行为相似的行为?

要了解为什么会这样,我们需要对 scikit-learn 源代码进行一些挖掘,从这里使用的 plot_precision_recall_curve 函数开始,然后沿着线程进入兔子洞......

plot_precision_recall_curve的source code开始,我们发现:

y_pred, pos_label = _get_response(
    X, estimator, response_method, pos_label=pos_label)

因此,为了绘制 PR 曲线,预测 y_pred不是直接由我们分类器的 predict 方法产生,而是由 scikit 的 _get_response() 内部函数产生-学习。

_get_response() 依次包含以下行:

prediction_method = _check_classifier_response_method(
    estimator, response_method)

y_pred = prediction_method(X)

这最终将我们引向_check_classifier_response_method() 内部函数;您可以查看完整的source code - 这里感兴趣的是else 声明之后的以下3 lines:

predict_proba = getattr(estimator, 'predict_proba', None)
decision_function = getattr(estimator, 'decision_function', None)
prediction_method = predict_proba or decision_function

现在,您可能已经开始明白这一点:在底层,plot_precision_recall_curve 检查所使用的分类器是否有 predict_proba()decision_function() 方法可用;如果predict_proba() 可用的,就像您在此处的带有铰链损失的SGDClassifier 的情况(或带有平方铰链损失的LinearSVC 分类器的documentation example),它会恢复为decision_function()方法,而是为了计算y_pred,随后将用于绘制 PR(和 ROC)曲线。


以上可以说已经回答了您的编程问题,即在这种情况下 scikit-learn 究竟是如何产生情节和基础计算的;关于是否以及为什么使用非概率分类器的decision_function() 确实是获得 PR(或 ROC)曲线的正确和合法方法的进一步理论调查超出了 SO 的范围,应发送至Cross Validated ,如有必要。

【讨论】:

@nz_21 不客气;这确实是一个合理的问题(现在在 SO 中并不典型),我很惊讶以前没有被问过(我在回答之前搜索了可能的重复项,但找不到任何重复项)。 不错的答案,赞成。我认为现在人们对 sklearn 的研究并不多,所以这很正常,以前没有人问过。

以上是关于SGD 分类器 Precision-Recall 曲线的主要内容,如果未能解决你的问题,请参考以下文章

使用 SGD 分类器和 GridsearchCV 寻找***特征

如果我们在 SGD 分类器上使用校准后的 cv 作为线性核,如何获得特征权重

在 sklearn 中,具有线性内核的 SVM 模型和具有 loss=hinge 的 SGD 分类器有啥区别

在 scikit-learn 库中使用 sgd 求解器的 SGDClassifier 与 LogisticRegression

在 scikit-learn 中使用交叉验证时绘制 Precision-Recall 曲线

SKlearn SGD Partial Fit 错误:特征数 378 与之前的数据不匹配 4598