scikit学习随机森林分类器概率阈值
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了scikit学习随机森林分类器概率阈值相关的知识,希望对你有一定的参考价值。
我用的是 sklearn RandomForestClassifier(随机森林分类器) 的预测任务。
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(n_estimators=300, n_jobs=-1)
model.fit(x_train,y_train)
model.predict_proba(x_test)
有171个班级需要预测,我只想预测那些班级,其中的 predict_proba(class)
是至少90%。下面的一切都应该设置为 0
.
例如,给定以下内容。
1 2 3 4 5 6 7
0 0.0 0.0 0.1 0.9 0.0 0.0 0.0
1 0.2 0.1 0.1 0.3 0.1 0.0 0.2
2 0.1 0.1 0.1 0.1 0.1 0.4 0.1
3 1.0 0.0 0.0 0.0 0.0 0.0 0.0
我的预期输出是:
0 4
1 0
2 0
3 1
答案
你可以使用 numpy.argwhere 如下。
from sklearn.ensemble import RandomForestClassifier
import numpy as np
model = RandomForestClassifier(n_estimators=300, n_jobs=-1)
model.fit(x_train,y_train)
preds = model.predict_proba(x_test)
#preds = np.array([[0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.0],
# [ 0.2, 0.1, 0.1, 0.3, 0.1, 0.0, 0.2],
# [ 0.1 ,0.1, 0.1, 0.1, 0.1, 0.4, 0.1],
# [ 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
r = np.zeros(preds.shape[0], dtype=int)
t = np.argwhere(preds>=0.9)
r[t[:,0]] = t[:,1]+1
r
array([4, 0, 0, 1])
另一答案
你可以使用列表理解。
import numpy as np
# dummy predictions - 3 samples, 3 classes
pred = np.array([[0.1, 0.2, 0.7],
[0.95, 0.02, 0.03],
[0.08, 0.02, 0.9]])
# first, keep only entries >= 0.9:
out_temp = np.array([[x[i] if x[i] >= 0.9 else 0 for i in range(len(x))] for x in pred])
out_temp
# result:
array([[0. , 0. , 0. ],
[0.95, 0. , 0. ],
[0. , 0. , 0.9 ]])
out = [0 if not x.any() else x.argmax()+1 for x in out_temp]
out
# result:
[0, 1, 3]
以上是关于scikit学习随机森林分类器概率阈值的主要内容,如果未能解决你的问题,请参考以下文章
使用 Scikit-Learn API 时如何调整 XGBoost 分类器中的概率阈值