scikit-learn 的 svm 的 predict_proba 的混淆概率

Posted

技术标签:

【中文标题】scikit-learn 的 svm 的 predict_proba 的混淆概率【英文标题】:Confusing probabilities of the predict_proba of scikit-learn's svm 【发布时间】:2015-08-20 20:30:44 【问题描述】:

我的目的是根据特定类别的每个样本的排序概率来绘制 PR 曲线。然而,我发现当我使用两个不同的标准数据集时,svm 的 predict_proba() 获得的概率有两种不同的行为:虹膜和数字。

第一个案例是用下面的python代码用“iris”案例评估的,它合理地工作,该类获得最高的概率。

D = datasets.load_iris()
clf = SVC(kernel=chi2_kernel, probability=True).fit(D.data, D.target)
output_predict = clf.predict(D.data)
output_proba = clf.predict_proba(D.data)
output_decision_function = clf.decision_function(D.data)
output_my = proba_to_class(output_proba, clf.classes_)

print D.data.shape, D.target.shape
print "target:", D.target[:2]
print "class:", clf.classes_
print "output_predict:", output_predict[:2]
print "output_proba:", output_proba[:2]

接下来,它会产生如下输出。显然,每个样本的最高概率与 predict() 的输出相匹配:样本 #1 为 0.97181088,样本 #2 为 0.96961523。

(150, 4) (150,)
target: [0 0]
class: [0 1 2]
output_predict: [0 0]
output_proba: [[ 0.97181088  0.01558693  0.01260218]
[ 0.96961523  0.01702481  0.01335995]]

但是,当我使用以下代码将数据集更改为“数字”时,概率揭示了一种相反的现象,即每个样本的最低概率主导 predict() 的输出标签,样本 #1 的概率为 0.00190932 和样品 #2 为 0.00220549。

D = datasets.load_digits()

输出:

(1797, 64) (1797,)
target: [0 1]
class: [0 1 2 3 4 5 6 7 8 9]
output_predict: [0 1]
output_proba: [[ 0.00190932  0.11212957  0.1092459   0.11262532      0.11150733  0.11208733
0.11156622  0.11043403  0.10747514  0.11101985]
[ 0.10991574  0.00220549  0.10944998  0.11288081  0.11178518   0.11234661
0.11182221  0.11065663  0.10770783  0.11122952]]

我已经阅读了this post,它提供了一个使用带有 decision_function() 的线性 SVM 的解决方案。但是,由于我的任务,我仍然必须专注于 SVM 的卡方内核。

有什么解决办法吗?

【问题讨论】:

既然你已经想通了,你还有什么问题? 我的问题是如何为卡方SVM的输出绘制PR曲线。谢谢:) 【参考方案1】:

作为documentation states,不能保证predict_probapredict 在SVC 上给出一致的结果。 您可以简单地使用decision_function。这对于线性和内核 SVM 都是如此。

【讨论】:

以上是关于scikit-learn 的 svm 的 predict_proba 的混淆概率的主要内容,如果未能解决你的问题,请参考以下文章

如何获得 scikit-learn SVM 分类器的所有 alpha 值?

目标的缩放导致 Scikit-learn SVM 回归崩溃

将 scikit-learn SVM 模型转换为 LibSVM

机器学习:SVM(scikit-learn 中的 SVM:LinearSVC)

用于文本分类的一类 SVM 模型(scikit-learn)

scikit-learn 中自定义内核 SVM 的交叉验证