为啥这个混淆矩阵(matplotlib)在 Jupyter Notebook 中看起来被压扁了? [复制]

Posted

技术标签:

【中文标题】为啥这个混淆矩阵(matplotlib)在 Jupyter Notebook 中看起来被压扁了? [复制]【英文标题】:Why does this confusion matrix (matplotlib) look squished in Jupyter Notebook? [duplicate]为什么这个混淆矩阵(matplotlib)在 Jupyter Notebook 中看起来被压扁了? [复制] 【发布时间】:2020-01-16 03:55:38 【问题描述】:

我正在为混淆矩阵运行下面的代码。在我重置笔记本内核之前,输出看起来很棒。我没有更改代码,但现在它看起来被压扁了(图 1)。当我删除 plt.yticks 行时它会更正(图 2),但我想要这些标签。这可能很简单,但我是 Python 新手。

import itertools
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion Matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    Source: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    # Plot the confusion matrix
    plt.figure(figsize = (6, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, size = 25)
    plt.colorbar(aspect=5)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, size = 12)
    plt.yticks(tick_marks, classes, size = 12)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.

    # Labeling the plot
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), fontsize = 20,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.grid(False)
    plt.tight_layout()
    plt.ylabel('Actual label', size = 15)
    plt.xlabel('Predicted label', size = 15)

cm = confusion_matrix(y_test, y_pred)
plot_confusion_matrix(cm, classes = ['Good Mental Health', 'Poor Mental Health'],
                      title = 'Confusion Matrix')

【问题讨论】:

【参考方案1】:

尝试在代码末尾添加这些行。

plt.tight_layout()
plt.show()

仅此一项就应该有很大帮助。更多建议:

1) 我认为您要求的是 6x6 英寸的数字,而这个空间包括标签。更大的数字可能会有所帮助。

2) 您可以尝试改进所需空间的使用方式。我肯定会要求标签在两个不同的行中:我猜你有

 tick_marks = ['good mental health', 'poor mental health']

在您的代码中的某处。正在做

 tick_marks = ['good \nmental health', 'poor \nmental health']

可能也有帮助。

3) 改善空间使用方式的另一种方法是旋转 ylabel:

 plt.yticks(tick_marks, classes, size = 12, rotation='vertical')

你应该尝试不同的组合,看看会发生什么。

【讨论】:

感谢您的提示。我已经在代码中有plt.tight_layout()plt.show()。我发现添加plt.ylim([1.5, -.5]) 解决了这个问题,尽管我仍然对为什么它在当天早些时候工作正常然后突然停止感到困惑。哦,好吧,它已经修复了! 你应该知道你第一次做了什么。一般来说,你可以相信通过清理内核并从一开始就运行它会发生什么。这会一次又一次地发生。 我在新的 matplotlib 版本中遇到了同样的问题,只有 plt.ylim([1.5, -.5]) 为我解决了这个问题

以上是关于为啥这个混淆矩阵(matplotlib)在 Jupyter Notebook 中看起来被压扁了? [复制]的主要内容,如果未能解决你的问题,请参考以下文章

python是不是有绘制混淆矩阵的函数,怎么来实现

为啥混淆矩阵中的总数与输入的数据不同?

python绘制混淆矩阵

TensorFlow:创建混淆矩阵时无法将图像转换为浮点数

标记为 TP、TN、FP、FN 的值的混淆矩阵

无法打印正确的混淆矩阵,并且热图中的值也在示例 2e+2、e+4 等中打印