类数,4,与 target_names 的大小,6 不匹配。尝试指定标签参数

Posted

技术标签:

【中文标题】类数,4,与 target_names 的大小,6 不匹配。尝试指定标签参数【英文标题】:Number of classes, 4, does not match size of target_names, 6. Try specifying the labels parameter 【发布时间】:2020-01-14 21:25:35 【问题描述】:

当我尝试制作我的 CNN 模型的混淆矩阵时遇到了一些问题。当我运行代码时,它会返回一些错误,例如:

print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

Traceback (most recent call last):

  File "<ipython-input-102-82d46efe536a>", line 1, in <module>
    print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

  File "G:\anaconda_installation_file\lib\site-packages\sklearn\metrics\classification.py", line 1543, in classification_report
    "parameter".format(len(labels), len(target_names))

ValueError: Number of classes, 4, does not match size of target_names, 6. Try specifying the labels parameter

我已经搜索过解决这个问题,但仍然没有得到完美的解决方案。 我是这个领域的新手,有人可以帮助我吗? 谢谢。

from sklearn.metrics import classification_report,confusion_matrix
import itertools

Y_pred = model.predict(X_test)
print(Y_pred)
y_pred = np.argmax(Y_pred, axis=1)
print(y_pred)

target_names = ['class 0(cardboard)', 'class 1(glass)', 'class 2(metal)','class 3(paper)', 'class 4(plastic)','class 5(trash)']

print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

【问题讨论】:

欢迎使用 SO,它通过按原样抛出我们所有的代码来工作; 错误之后出现的代码是多余的并且与问题无关(它永远不会执行)并且只会增加不必要的混乱(已删除)。请参阅How to create a Minimal, Complete, and Verifiable example。 【参考方案1】:

您应该更好地提出您的问题!我在做一些假设! 问题是:

target_names = ['class 0(cardboard)', 'class 1(glass)', 'class 2(metal)','class 3(paper)', 'class 4(plastic)','class 5(垃圾)']

有 6 个类,您的模型只能预测 4 个会引发错误的类,因为混淆矩阵提供了 4 个类(它应该是 6x6 而不是 6x4)。 要纠正这个问题,只需提供类标签。 对于 ecample,如果有 3 个标签(在预测变量中)即 1,2,3

打印(分类报告(y_true,y_pred,标签=[1、2、3]))

请参阅此处的文档https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html

PS:

    您的模型表现不佳。

    您的数据集可能存在类别不平衡问题。

【讨论】:

【参考方案2】:

问题是您有 6 个标签名称:'class 0(cardboard)', 'class 1(glass)', 'class 2(metal)','class 3(paper)', 'class 4(plastic)','class 5(trash)'

但是你的confusion_matrix中只有4个类,当你打印时:print(y_pred):你会得到带有0,1,2,3的数字的东西,或者当你print(y_test)你会从0,1,2,3得到数字时,它应该有助于删除:

print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

根据您的代码,不知何故您没有 6 个预测/测试类。

这里还有一个如何绘制混淆矩阵的示例:How can I plot a confusion matrix?

【讨论】:

以上是关于类数,4,与 target_names 的大小,6 不匹配。尝试指定标签参数的主要内容,如果未能解决你的问题,请参考以下文章

Java 应用程序中加载的类数中可能存在内存泄漏

第十四章聚类方法.14.2.4确定最佳聚类数

Linux——网络配置及命令

ping

用 R 将列的类数修改为分位数组

准确率,精确度(AP)与召回率(AR)