Sci-kit分类阈值

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Sci-kit分类阈值相关的知识,希望对你有一定的参考价值。

所以我正在使用scikit-learn做一些二进制分类,现在我正在尝试使用Logistic回归分类器。在训练分类器之后,我打印出分类结果以及它们在每个类中的概率:

logreg = LogisticRegression()
logreg.fit(X_train,y_train)
print logreg.predict(X_test)
print logreg.predict_proba(X_test)

所以我得到类似的东西:

[-1 1 1 -1 1 -1...-1]
[[  8.64625237e-01   1.35374763e-01]
 [  3.57441028e-01   6.42558972e-01]
 [  1.67970096e-01   8.32029904e-01]
 [  9.20026249e-01   7.99737513e-02]
 [  1.20456011e-02   9.87954399e-01]
 [  6.48565595e-01   3.51434405e-01]...]

等等...所以每当概率超过0.5时,它就像是被分类为的对象。我正在寻找一种方法来调整这个数字,以便例如,在第1类中的概率必须超过.7才能被分类。有没有办法做到这一点?我正在看一些已经像'tol'和'weight'的参数,但我不确定它们是不是我正在寻找的,或者它们是否在工作......

答案

你可以这样设置你的THRESHOLD

THRESHOLD = 0.7
preds = np.where(logreg.predict_proba(X_test)[:,1] > THRESHOLD, 1, 0)

请参考sklearn LogisticRegression and changing the default threshold for classification

以上是关于Sci-kit分类阈值的主要内容,如果未能解决你的问题,请参考以下文章

sci-kit 中使用的特性学习 DT 的实现

使用 Sci-Kit 学习对具有大型语料库的文本进行分类

如何用sci-kit learn识别误分类文本文件的ID/名称/标题

如何在 Sci-Kit Learn / Tensorflow 中检测对象大小?

使用 Sci-kit Learn SVM 时预测始终相同

如何更改二元分类的阈值