SGD 分类器 Precision-Recall 曲线
Posted
技术标签:
【中文标题】SGD 分类器 Precision-Recall 曲线【英文标题】:SGD classifier Precision-Recall curve 【发布时间】:2021-07-31 06:14:46 【问题描述】:我正在研究一个二元分类问题,我有一个像这样的 sgd 分类器:
sgd = SGDClassifier(
max_iter = 1000,
tol = 1e-3,
validation_fraction = 0.2,
class_weight = 0:0.5, 1:8.99
)
我将它安装在我的训练集上并绘制了精确召回曲线:
from sklearn.metrics import plot_precision_recall_curve
disp = plot_precision_recall_curve(sgd, X_test, y_test)
鉴于 scikit-learn 中的 sgd 分类器默认使用loss="hinge"
,那么这条曲线怎么可能绘制出来呢?我的理解是 sgd 的输出不是概率的——要么是 1/0。因此没有“阈值”,但是 sklearn 精确召回曲线绘制了具有不同阈值的锯齿形图。这是怎么回事?
【问题讨论】:
【参考方案1】:您描述的情况实际上与 documentation example 中的情况相同,使用前 2 类虹膜数据和 LinearSVC 分类器(该算法使用平方铰链损失,就像您在此处使用的铰链损失一样,导致分类器只产生二元结果而不是概率结果)。产生的情节是:
即这里的质量和你的差不多。
尽管如此,您的问题是一个合理的问题,而且确实是一个不错的问题;当我们的分类器确实不产生概率预测时(因此任何阈值的概念听起来都无关紧要),我们怎么会得到与概率分类器产生的行为相似的行为?
要了解为什么会这样,我们需要对 scikit-learn 源代码进行一些挖掘,从这里使用的 plot_precision_recall_curve
函数开始,然后沿着线程进入兔子洞......
从plot_precision_recall_curve
的source code开始,我们发现:
y_pred, pos_label = _get_response(
X, estimator, response_method, pos_label=pos_label)
因此,为了绘制 PR 曲线,预测 y_pred
不是直接由我们分类器的 predict
方法产生,而是由 scikit 的 _get_response()
内部函数产生-学习。
_get_response()
依次包含以下行:
prediction_method = _check_classifier_response_method(
estimator, response_method)
y_pred = prediction_method(X)
这最终将我们引向_check_classifier_response_method()
内部函数;您可以查看完整的source code - 这里感兴趣的是else
声明之后的以下3 lines:
predict_proba = getattr(estimator, 'predict_proba', None)
decision_function = getattr(estimator, 'decision_function', None)
prediction_method = predict_proba or decision_function
现在,您可能已经开始明白这一点:在底层,plot_precision_recall_curve
检查所使用的分类器是否有 predict_proba()
或 decision_function()
方法可用;如果predict_proba()
是不 可用的,就像您在此处的带有铰链损失的SGDClassifier 的情况(或带有平方铰链损失的LinearSVC 分类器的documentation example),它会恢复为decision_function()
方法,而是为了计算y_pred
,随后将用于绘制 PR(和 ROC)曲线。
以上可以说已经回答了您的编程问题,即在这种情况下 scikit-learn 究竟是如何产生情节和基础计算的;关于是否以及为什么使用非概率分类器的decision_function()
确实是获得 PR(或 ROC)曲线的正确和合法方法的进一步理论调查超出了 SO 的范围,应发送至Cross Validated ,如有必要。
【讨论】:
@nz_21 不客气;这确实是一个合理的问题(现在在 SO 中并不典型),我很惊讶以前没有被问过(我在回答之前搜索了可能的重复项,但找不到任何重复项)。 不错的答案,赞成。我认为现在人们对 sklearn 的研究并不多,所以这很正常,以前没有人问过。以上是关于SGD 分类器 Precision-Recall 曲线的主要内容,如果未能解决你的问题,请参考以下文章
使用 SGD 分类器和 GridsearchCV 寻找***特征
如果我们在 SGD 分类器上使用校准后的 cv 作为线性核,如何获得特征权重
在 sklearn 中,具有线性内核的 SVM 模型和具有 loss=hinge 的 SGD 分类器有啥区别
在 scikit-learn 库中使用 sgd 求解器的 SGDClassifier 与 LogisticRegression