Sci-kit 分类阈值
Posted
技术标签:
【中文标题】Sci-kit 分类阈值【英文标题】:Sci-kit Classifying Thresholds 【发布时间】:2016-12-01 20:55:45 【问题描述】:所以我正在使用 scikit-learn 进行一些二元分类,现在我正在尝试 Logistic Regression 分类器。训练完分类器后,我打印出分类结果和它们在每个类中的概率:
logreg = LogisticRegression()
logreg.fit(X_train,y_train)
print logreg.predict(X_test)
print logreg.predict_proba(X_test)
所以我得到类似的东西:
[-1 1 1 -1 1 -1...-1]
[[ 8.64625237e-01 1.35374763e-01]
[ 3.57441028e-01 6.42558972e-01]
[ 1.67970096e-01 8.32029904e-01]
[ 9.20026249e-01 7.99737513e-02]
[ 1.20456011e-02 9.87954399e-01]
[ 6.48565595e-01 3.51434405e-01]...]
等等......所以看起来只要概率超过0.5,这就是对象被分类的原因。我正在寻找一种方法来调整这个数字,例如,属于 1 类的概率必须超过 0.7 才能被归类为此类。有没有办法做到这一点?我正在查看一些参数,例如“tol”和“重量”,但我不确定它们是否是我正在寻找的,或者它们是否正在工作......
【问题讨论】:
***.com/questions/31417487/…的可能重复 如果你得到了预测的概率(就像你所做的那样),那么做某事就很容易了 (probas[:,1]>=threshold).astype(int) 【参考方案1】:你可以像这样设置你的THRESHOLD
THRESHOLD = 0.7
preds = np.where(logreg.predict_proba(X_test)[:,1] > THRESHOLD, 1, 0)
请参考sklearn LogisticRegression and changing the default threshold for classification
【讨论】:
以上是关于Sci-kit 分类阈值的主要内容,如果未能解决你的问题,请参考以下文章