绘制阈值(precision_recall 曲线)matplotlib/sklearn.metrics

Posted

技术标签:

【中文标题】绘制阈值(precision_recall 曲线)matplotlib/sklearn.metrics【英文标题】:Plotting Threshold (precision_recall curve) matplotlib/sklearn.metrics 【发布时间】:2021-05-04 15:11:04 【问题描述】:

我正在尝试绘制精度/召回曲线的阈值。我只是使用 MNSIT 数据,其中的示例来自使用 scikit-learn、keras 和 TensorFlow 进行机器学习一书。试图训练模型来检测 5 的图像。我不知道你需要看多少代码。我已经为训练集制作了混淆矩阵,并计算了精度和召回值以及阈值。我已经绘制了 pre/rec 曲线,书中的示例说要添加轴标签、壁架、网格并突出显示阈值,但代码在书中我在下面放置星号的地方被切断了。除了如何让阈值出现在情节上之外,我能够弄清楚所有事情。我附上了一张书中图表与我所拥有的图表的图片。 这就是本书所展示的:

对比我的图表:

我无法显示带有两个阈值点的红色虚线。有谁知道我会怎么做?下面是我的代码:

from sklearn.metrics import precision_recall_curve

precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

def plot_precision_recall_vs_thresholds(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g--", label="Recall")
    plt.xlabel("Threshold")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    plt.grid(b=True, which="both", axis="both", color='gray', linestyle='-', linewidth=1)

plot_precision_recall_vs_thresholds(precisions, recalls, thresholds)
plt.show()

我知道这里有很多关于 sklearn 的问题,但似乎没有一个涉及让红线出现。非常感谢您的帮助!

【问题讨论】:

您可以看到如何在两点之间绘制线,然后您可以指定坐标并绘制一条线。对于点,您可以使用散点图。 您可以按照提供的答案中的建议进行操作。我还建议看看作者的整个snippet。 @amiola 谢谢!!!这正是我正在解决的问题! 【参考方案1】:

您可以使用以下代码绘制水平线和垂直线:

plt.axhline(y_value, c='r', ls=':')
plt.axvline(x_value, c='r', ls=':')

【讨论】:

【参考方案2】:

这应该以确切的方式工作:

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    recall_80_precision = recalls[np.argmax(precisions >= 0.80)]
    threshold_80_precision = thresholds[np.argmax(precisions >= 0.80)]
    
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
    plt.xlabel("Threshold")
    plt.plot([threshold_80_precision, threshold_80_precision], [0., 0.8], "r:")
    plt.axis([-4, 4, 0, 1])
    plt.plot([-4, threshold_80_precision], [0.8, 0.8], "r:")
    plt.plot([-4, threshold_80_precision], [recall_80_precision, recall_80_precision], "r:")
    plt.plot([threshold_80_precision], [0.8], "ro") 
    plt.plot([threshold_80_precision], [recall_80_precision], "ro")
    plt.grid(True)
    plt.legend()
    plt.show()

我在尝试复制本书中的代码时遇到了这段代码。原来@ageron 将所有资源都放在了他的 github 页面上。你可以去看看here

【讨论】:

以上是关于绘制阈值(precision_recall 曲线)matplotlib/sklearn.metrics的主要内容,如果未能解决你的问题,请参考以下文章

Sklearn机器学习——ROC曲线ROC曲线的绘制和AUC面积运用ROC曲线找到最佳阈值

R语言使用pROC包绘制ROC曲线获取最优阈值(threshold)及最优阈值对应的置信区间

python 使用sklearn绘制roc曲线选取合适的分类阈值

ROC 曲线和 libsvm

PR详解及二分类的PR曲线绘制

从 ROC 曲线获取阈值