如何根据 GradientBoost 结果绘制热图?

Posted

技术标签:

【中文标题】如何根据 GradientBoost 结果绘制热图?【英文标题】:How to plot heatmap based on GradientBoost results? 【发布时间】:2019-08-18 11:25:41 【问题描述】:

我想根据 y_predict 和 y_train 的结果打印混淆矩阵的热图。

我有点卡住了,我已经查看了热图的 pandas 文档,但仍然不知道如何将它应用到我的结果中。我使用的数据集是关于收入的,有分类和数字数据。我已经应用了 GB 分类器并得到了结果。 唯一剩下的就是热图。

print(confusion_matrix(y_train,y_pred_train))
print(y_train)

这就是结果

Confusion Matrix:


[[14151   710]
 [ 1844  2831]]
Name: income, Length: 19536, dtype: int64

这是制作热图的尝试

import seaborn as sns
class_names = y_train, y_pred_train

def print_confusion_matrix(confusion_matrix, class_names, figsize = (10,7), fontsize=14):
    df_cm = pd.DataFrame(
        confusion_matrix, index=class_names, columns=class_names, 
    )
    fig = plt.figure(figsize=figsize)
    try:
        heatmap = sns.heatmap(df_cm, annot=True, fmt="d")
    except ValueError:
        raise ValueError("Confusion matrix values must be integers.")
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    return fig

返回的

NameError                                 Traceback (most recent call last)
<ipython-input-36-3bd0e9ee90a4> in <module>()
     18     plt.xlabel('Predicted label')
     19     return fig
---> 20 print(fig)

NameError: name 'fig' is not defined

那么当我在我的结果上制作混淆矩阵的热图时,我错过了什么?

【问题讨论】:

你可以用你的混淆矩阵来调用函数,例子见我的回答 print(fig)的错误是fig是在函数内部定义的,如果要在函数范围之外使用,可以用fig = print_confusion_matrix(...设置,因为它是从您的函数返回 return fig 【参考方案1】:

您可以使用混淆矩阵和类名列表作为参数调用print_confusion_matrix 函数:

def print_confusion_matrix(confusion_matrix, class_names, figsize = (10,7), fontsize=14):
    df_cm = pd.DataFrame(
        confusion_matrix, index=class_names, columns=class_names, 
    )
    fig = plt.figure(figsize=figsize)
    try:
        heatmap = sns.heatmap(df_cm, annot=True, fmt="d")
    except ValueError:
        raise ValueError("Confusion matrix values must be integers.")
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    return fig

confusion_matrix = np.array([[14151, 710], [1844, 2831]])
fig = print_confusion_matrix(confusion_matrix, ['0', '1'])

输出:

【讨论】:

非常感谢您的帮助!

以上是关于如何根据 GradientBoost 结果绘制热图?的主要内容,如果未能解决你的问题,请参考以下文章

如何在绘制多个热图时修复plt.tight_layout()错误

运用 R语言 pheatmap 包绘制热图进阶部分

Autodock Vina多对多分子对接并筛选结果绘制热图

如何用R语言绘制热图

案例演示 R语言绘制热图代码

一起来学习如何使用R语言绘制热图