如何用python画好confusion matrix

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了如何用python画好confusion matrix相关的知识,希望对你有一定的参考价值。

参考技术A

在做分类的时候,经常需要画混淆矩阵,下面我们使用python的matplotlib包,scikit-learning机器学习库也同样提供了例子:, 但是这样的图并不能满足我们的要求,

首先是刻度的显示是在方格的中间,这需要隐藏刻度,其次是如何把每个label显示在每个方块的中间, 其次是如何在每个方格中显示accuracy数值, 最后是如何在横坐标和纵坐标显示label的名字,在label name比较长的时候,如何处理显示问题。

直接贴上代码:

[python] view plain copy

    '''''compute confusion matrix 

    labels.txt: contain label name. 

    predict.txt: predict_label true_label 

    '''  

    from sklearn.metrics import confusion_matrix  

    import matplotlib.pyplot as plt  

    import numpy as np  

    #load labels.  

    labels = []  

    file = open('labels.txt', 'r')  

    lines = file.readlines()  

    for line in lines:  

    labels.append(line.strip())  

    file.close()  

    y_true = []  

    y_pred = []  

    #load true and predict labels.  

    file = open('predict.txt', 'r')  

    lines = file.readlines()  

    for line in lines:  

    y_true.append(int(line.split(" ")[1].strip()))  

    y_pred.append(int(line.split(" ")[0].strip()))  

    file.close()  

    tick_marks = np.array(range(len(labels))) + 0.5  

    def plot_confusion_matrix(cm, title='Confusion Matrix', cmap = plt.cm.binary):  

    plt.imshow(cm, interpolation='nearest', cmap=cmap)  

    plt.title(title)  

    plt.colorbar()  

    xlocations = np.array(range(len(labels)))  

    plt.xticks(xlocations, labels, rotation=90)  

    plt.yticks(xlocations, labels)  

    plt.ylabel('True label')  

    plt.xlabel('Predicted label')  

    cm = confusion_matrix(y_true, y_pred)  

    print cm  

    np.set_printoptions(precision=2)  

    cm_normalized = cm.astype('float')/cm.sum(axis=1)[:, np.newaxis]  

    print cm_normalized  

    plt.figure(figsize=(12,8), dpi=120)  

    #set the fontsize of label.  

    #for label in plt.gca().xaxis.get_ticklabels():  

    #    label.set_fontsize(8)  

    #text portion  

    ind_array = np.arange(len(labels))  

    x, y = np.meshgrid(ind_array, ind_array)  

    for x_val, y_val in zip(x.flatten(), y.flatten()):  

    c = cm_normalized[y_val][x_val]  

    if (c > 0.01):  

    plt.text(x_val, y_val, "%0.2f" %(c,), color='red', fontsize=7, va='center', ha='center')  

    #offset the tick  

    plt.gca().set_xticks(tick_marks, minor=True)  

    plt.gca().set_yticks(tick_marks, minor=True)  

    plt.gca().xaxis.set_ticks_position('none')  

    plt.gca().yaxis.set_ticks_position('none')  

    plt.grid(True, which='minor', linestyle='-')  

    plt.gcf().subplots_adjust(bottom=0.15)  

    plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')  

    #show confusion matrix  

    plt.show()  

    结果如下图所示:

    阅读全文

    版权声明:本文为博主原创文章,未经博主允许不得转载。

    目前您尚未登录,请 登录 或 注册 后进行评论

    linchunmian

    2017-05-08 2

以上是关于如何用python画好confusion matrix的主要内容,如果未能解决你的问题,请参考以下文章

如何用Visio画venn(维恩)图

如何用DXP画数显卡尺板

如何用AltiumDesigner绘制STC89C51单片机原理图

如何用word画一幅二叉树图啊?

如何用层(DIV)给一张图片上的固定文字添加超链接?

如何用python抓取电话