Scikit Learn - 如何绘制概率

Posted

技术标签:

【中文标题】Scikit Learn - 如何绘制概率【英文标题】:Scikit Learn - How to plot probabilities 【发布时间】:2018-01-24 16:23:20 【问题描述】:

我想绘制模型的预测概率。

plt.scatter(y_test, prediction[:,0])
plt.xlabel("True Values")
plt.ylabel("Predictions")
plt.show()

但是,我得到了类似上面的图表。哪种有意义,但我想更好地可视化概率分布。有没有一种方法可以让我的实际类为 0 或 1 并且预测介于 0 和 1 之间。

【问题讨论】:

【参考方案1】:

您可以根据真实值拆分值,然后绘制两个类的值的两个直方图,例如使用以下内容(至少如果您有一个 numpy 数组 arr_truearr_pred 这应该可以工作):

arr_true_0_indices = (y_test == 0.0)
arr_true_1_indices = (y_test == 1.0)

arr_pred_0 = prediction[arr_true_0_indices]
arr_pred_1 = prediction[arr_true_1_indices]

plt.hist(arr_pred_0, bins=40, label='True class 0', normed=True, histtype='step')
plt.hist(arr_pred_1, bins=40, label='True class 1', normed=True, histtype='step')
plt.xlabel('Network output')
plt.ylabel('Arbitrary units / probability')
plt.legend(loc='best')
plt.show()

这应该是这样的:

【讨论】:

【参考方案2】:

预测概率可用于可视化模型性能。真正的标签可以用颜色来表示。

试试这个例子:

from sklearn.datasets import make_classification
import matplotlib.pyplot as plt

X, y = make_classification(n_samples=1000, n_features=4,
                           n_informative=2, n_redundant=0,
                           random_state=1, shuffle=False)
from sklearn.linear_model import LogisticRegression

lr=LogisticRegression(random_state=0, solver='lbfgs', max_iter=10)
lr.fit(X, y)

prediction=lr.predict_proba(X)[:,1]

plt.figure(figsize=(15,7))
plt.hist(prediction[y==0], bins=50, label='Negatives')
plt.hist(prediction[y==1], bins=50, label='Positives', alpha=0.7, color='r')
plt.xlabel('Probability of being Positive Class', fontsize=25)
plt.ylabel('Number of records in each bucket', fontsize=25)
plt.legend(fontsize=15)
plt.tick_params(axis='both', labelsize=25, pad=5)
plt.show() 

【讨论】:

是否有 pyplot 设置可以堆叠两个历史类别而不是叠加?

以上是关于Scikit Learn - 如何绘制概率的主要内容,如果未能解决你的问题,请参考以下文章

如何使用 Tensorflow 和 scikit-learn 绘制 ROC 曲线?

如何在用 scikit-learn / matplotlib 绘制的混淆矩阵中格式化 xticklabels?

如何在 scikit learn 中绘制逻辑回归的决策边界

如何在 Scikit-Learn 中绘制超过 10 倍交叉验证的 PR 曲线

如何使用 scikit-learn 和 matplotlib 为不平衡数据集绘制 SVC 分类?

使用 matplotlib 绘制 scikit learn 线性回归结果