ConfusionMatrixDisplay(Scikit-Learn)绘图标签超出范围

Posted

技术标签:

【中文标题】ConfusionMatrixDisplay(Scikit-Learn)绘图标签超出范围【英文标题】:ConfusionMatrixDisplay (Scikit-Learn) plot labels out of range 【发布时间】:2021-08-10 16:26:42 【问题描述】:

以下代码绘制了一个混淆矩阵:

from sklearn.metrics import ConfusionMatrixDisplay

confusion_matrix = confusion_matrix(y_true, y_pred)
target_names = ["aaaaa", "bbbbbb", "ccccccc", "dddddddd", "eeeeeeeeee", "ffffffff", "ggggggggg"]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=target_names)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.savefig("conf.png")

这个情节有两个问题。

    y 轴标签被切断(真实标签)。 x 标签也被切断了。 x 轴名称过长。

为了解决第一个问题,我尝试使用poof(bbox_inches='tight'),不幸的是它不适用于 sklearn。 在第二种情况下,我为2. 尝试了以下解决方案,这会导致情节完全扭曲。

总而言之,我正在努力解决这两个问题。

【问题讨论】:

尝试像import matplotlib.pyplot as plt plt.rcParams["figure.figsize"] = (15,10)这样的rcParams。所以改变宽度,高度。 试过了,没有任何改变 plt.rcParams["figure.figsize"] = (15,10) 是在 disp.plot 之前添加的吗?或者我的意思是需要在情节之前。 这实际上完成了工作。现在情节周围有一大片空白区域。我需要手动裁剪吗? 您可以调整 (15,10) 使其最适合标签。 【参考方案1】:

我认为最简单的方法是切换到 tight_layout 并添加 pad_inches= 一些东西。

from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from numpy.random import default_rng

rand = default_rng()
y_true = rand.integers(low=0, high=7, size=500)
y_pred = rand.integers(low=0, high=7, size=500)


confusion_matrix = confusion_matrix(y_true, y_pred)
target_names = ["aaaaa", "bbbbbb", "ccccccc", "dddddddd", "eeeeeeeeee", "ffffffff", "ggggggggg"]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=target_names)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)

plt.tight_layout()
plt.savefig("conf.png", pad_inches=5)

结果:

【讨论】:

以上是关于ConfusionMatrixDisplay(Scikit-Learn)绘图标签超出范围的主要内容,如果未能解决你的问题,请参考以下文章

混淆矩阵颜色匹配数据大小而不是分类精度

不能将自定义非线性颜色图与 imshow 结合使用

SC接口光模块相关知识

不同 SC 系列的智能卡探测:清除 SC 状态的命令

为啥 sc.next() 或 sc.nextLine() 在循环中不起作用?

Spark:sc.textFiles() 与 sc.wholeTextFiles() 的区别