python 可视化混淆矩阵

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了python 可视化混淆矩阵相关的知识,希望对你有一定的参考价值。

import matplotlib.pyplot as plt
import itertools
from sklearn.metrics import classification_report

def pretty_print_conf_matrix(y_true, y_pred, 
                             classes,
                             normalize=False,
                             title='Confusion matrix',
                             cmap=plt.cm.Blues):
    """
    Mostly stolen from: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py

    Normalization changed, classification_report stats added below plot
    """

    cm = confusion_matrix(y_true, y_pred)

    # Configure Confusion Matrix Plot Aesthetics (no text yet) 
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, fontsize=14)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    plt.ylabel('True label', fontsize=12)
    plt.xlabel('Predicted label', fontsize=12)

    # Calculate normalized values (so all cells sum to 1) if desired
    if normalize:
        cm = np.round(cm.astype('float') / cm.sum(),2) #(axis=1)[:, np.newaxis]

    # Place Numbers as Text on Confusion Matrix Plot
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black",
                 fontsize=12)


    # Add Precision, Recall, F-1 Score as Captions Below Plot
    rpt = classification_report(y_true, y_pred)
    rpt = rpt.replace('avg / total', '      avg')
    rpt = rpt.replace('support', 'N Obs')

    plt.annotate(rpt, 
                 xy = (0,0), 
                 xytext = (-50, -140), 
                 xycoords='axes fraction', textcoords='offset points',
                 fontsize=12, ha='left')    

    # Plot
    plt.tight_layout()
pretty_print_conf_matrix(Y_validation, predictions,
                         class_labels, normalize=True, 
                         title='Confusion matrix',
                         cmap=plt.cm.Blues)

以上是关于python 可视化混淆矩阵的主要内容,如果未能解决你的问题,请参考以下文章

python使用sklearn的ConfusionMatrixDisplay来可视化混淆矩阵

python 可视化混淆矩阵

Python混淆矩阵可视化:plt.colorbar函数自定义颜色条的数值标签配置不同情况下颜色条的数值范围以及数据类型(整型浮点型)

R语言自定义多分类混淆矩阵可视化函数(mutlti class confusion matrix)R语言多分类混淆矩阵可视化

R语言使用caret包的confusionMatrix函数计算混淆矩阵使用编写的自定义函数可视化混淆矩阵(confusion matrix)

sklearn使用投票器VotingClassifier算法构建多模型融合的硬投票器分类器(hard voting)并计算融合模型的混淆矩阵可视化混淆矩阵(confusion matrix)