如何在用 scikit-learn / matplotlib 绘制的混淆矩阵中格式化 xticklabels?
Posted
技术标签:
【中文标题】如何在用 scikit-learn / matplotlib 绘制的混淆矩阵中格式化 xticklabels?【英文标题】:How to format xticklabels in a confusion matrix plotted with scikit-learn / matplotlib? 【发布时间】:2019-12-17 10:58:52 【问题描述】:由于我在网上找到了不同的代码示例,我已经用 scikit-learn / matplotlib 绘制了一个混淆矩阵,但我一直在寻找如何在 xticklabels 和主标题之间添加空格。如下图所示,绘图标题和 xticklabels 重叠(+ ylabel 'True' 被剪掉了)。
Link to my confusion matrix image
这是我使用的函数:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
PLOTS = '/plots/' # Output folder
def plt_confusion_matrix(y_test, y_pred, normalize=False, title="Confusion matrix"):
"""
Plots a nice confusion matrix.
:param y_test: list of predicted labels
:param y_pred: list of labels that should have been predicted.
:param normalize: boolean. If False, the plots shows the number of sentences predicted.
If True, shows the percentage of sentences predicted.
:param title: string. Title of the plot.
:return: Nothing but saves the plot as a PNG file and shows it.
"""
labels = list(set(y_pred))
cm = confusion_matrix(y_test, y_pred, labels)
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(cm, cmap=plt.cm.binary, interpolation='nearest')
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig.suptitle(title, fontsize=14, wrap=True)
fig.colorbar(cax)
ax.set_xticklabels([''] + labels, rotation=45)
ax.set_yticklabels([''] + labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.subplots_adjust(hspace=0.6)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 1.5 if normalize else cm.max() / 2
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(j, i, format(cm[i, j], fmt),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
plt.savefig(PLOTS + title)
plt.show()
我不得不旋转 xticklabels,因为它们太长,否则会水平重叠,而且我不得不包裹标题,因为它也太长,否则无法完全显示在图像中。
我在另一篇文章中看到 xticklabels 也可以放在图形下方(就像在this *** post 中一样),所以也许这可能是一个解决方案,但我不知道如何制作。
我该如何解决这个问题?
在标题和 xticklabels 之间添加一些空格 (让它们看起来完全是 btw); 或使 ylabel 'True' 可见 或移动图形下方的 xticklabels。编辑:我尝试了两种 geekzeus 解决方案,但均未成功...
geekzeus 的第一个解决方案的结果:See confusion matrix geekzeus 的第二个解决方案的结果:See confusion matrix【问题讨论】:
【参考方案1】:您可以使用参数x
和y
指定标题的位置。如果您调整y
的值,可以生成所需的图。
fig.suptitle(title, fontsize=14, wrap=True, y=1.2)
【讨论】:
【参考方案2】:这样做
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
#xaxisticks
ax.xaxis.set_ticklabels(['A', 'B'])
#yaxis ticks
ax.yaxis.set_ticklabels(['B', 'A'])
或 使用 seaborn 和 matplotlib,也可以直接提供 list 变量给 ticks
import seaborn as sns
import matplotlib.pyplot as plt
cm = confusion_matrix(true_classes, predicted_classes)
ax= plt.subplot()
sns.heatmap(cm, annot=True, ax = ax); #annot=True to annotate cells
# labels, title and ticks
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
ax.xaxis.set_ticklabels(['A', 'B'])
ax.yaxis.set_ticklabels(['B', 'A'])
【讨论】:
我认为它正在工作,因为 PyCharm 预览显示了一个很好的结果,但保存的图像仍然是相同的 :'( 这是我使用的代码:ax.set_xlabel('Predicted') ax .set_ylabel('True') ax.set_title('\n\n\n\n\n') ax.xaxis.set_ticklabels([''] + labels, rotation=45) ax.yaxis.set_ticklabels(['' ] + 标签) 我也用 seaborn 尝试了你的解决方案:xtick 标签在下面,但仍然切入保存的 PNG 文件(尽管在预览中没问题)。但是,ytick 标签是重叠的,我找不到如何将它们水平放置(rotation=90 不起作用)。更糟糕的是:图中显示的数字处于科学模式,如“1.9e+03”...... 我你还没有解决它..比提供数据和代码与可重现的例子和你需要的例子输出以上是关于如何在用 scikit-learn / matplotlib 绘制的混淆矩阵中格式化 xticklabels?的主要内容,如果未能解决你的问题,请参考以下文章
[机器学习与scikit-learn-2]:如何学习Scikit-learn