Python - 混淆矩阵维度的差异

Posted

技术标签:

【中文标题】Python - 混淆矩阵维度的差异【英文标题】:Python - difference in confusion matrix dimension 【发布时间】:2018-08-22 08:02:45 【问题描述】:

我有一个关于混淆矩阵的问题。我使用交叉验证将 148 个实例拆分为两个数组 - 测试和训练。比我这样称呼:

def GenerateResult:
   clf = OneVsRestClassifier(GaussianNB())
   clf.fit(x_train, y_train)
   predictions = clf.predict(x_test)
   accuracy = accuracy_score(y_test, predictions)
   confusion_mtrx = confusion_matrix(y_test, predictions)

这是 KFold 的循环 -> 我从上调用函数:

for train_idx, test_idx in pf.split(x_array):
       x_train, x_test = x_array[train_idx], x_array[test_idx]
       y_train, y_test = y_array[train_idx], y_array[test_idx]
       acc, confusion= GenerateResult(x_train, x_test, y_train, y_test)
       results['First'].append(acc)
       confusion_dict['First'].append(confusion)

然后我对结果求和并计算平均值

np_gausian = np.asarray(results['gaussian'])
print("[First] Mean: ".format(np.mean(np_gausian)))

print(confusion_dict['gaussian'])

我有一个问题。在我的 148 个实例中,我有 4 个输出类,当我将该循环用于 KFold 时,我得到了两个不同的混淆矩阵。 第一个混淆矩阵 3x3:

[[36  1  1]

 [15 17  1]

 [ 0  0  3]]

第二个 4x4 :

[[ 0  2  0  0]

 [ 0 41  2  0]

 [ 0 12 16  0]

 [ 0  0  1  0]]

我认为我有问题,因为在我的 148 实例中我有

1 - 2 类

2 类 - 81 个

3 级 - 61 次

4 - 4 类

所有类 - 148

我该怎么办?我怎样才能总结这个混淆矩阵?如果我更改 KFold 中的拆分数量怎么办?我尝试使用 Pandas,但我不知道该怎么做。请帮忙,我用的是sk-learn

【问题讨论】:

我认为您的问题源于这样一个事实,即第 1 类只有两个观察结果,当您拆分测试时,它们都属于第二类。尝试使用 StratifiedKFold 而不是 KFold,如果这不能帮助手动将其中一个观察结果从第 1 类移动到另一个折叠。 【参考方案1】:

正如@KRKirov 在评论中指出的那样,其原因是由于 Kfold 交叉验证将数据拆分为折叠,某些类不存在于该折叠的测试集中。

例如,如果 class1 在y_test 中不存在,并且在predictions 中也没有预测,那么confusion_matrix 代码将自动推断数据中只存在三个类并根据那个。

您可以通过设置labels param 来强制混淆矩阵使用所有类:-

标签:数组,形状 = [n_classes],可选

List of labels to index the matrix. This may be used to reorder or
select a subset of labels. If none is given, those that appear at
least once in y_true or y_pred are used in sorted order.

通过这样做:

confusion_mtrx = confusion_matrix(y_test, predictions, 
                                 labels = np.unique(y_array))

您需要将 y_array 或唯一标签从 y_array 传递给 GenerateResult() 方法。

【讨论】:

好的,可以了。也来自@KRKirov作品的回答。但我还有另一个问题 - 我应该如何总结我得到的混淆矩阵? 好的,我找到了。我们应该将所有混淆矩阵与该代码相加。谢谢大家 @Zarobiek 如果有帮助,请考虑投票/接受答案。

以上是关于Python - 混淆矩阵维度的差异的主要内容,如果未能解决你的问题,请参考以下文章

Weka中决策树和混淆矩阵中正确/错误分类实例之间的差异

混淆矩阵是什么?Python多分类的混淆矩阵计算及可视化(包含原始混淆矩阵及归一化的混淆矩阵):基于skelarn框架iris数据集

Python使用pandas的crosstab函数计算混淆矩阵并使用Seaborn可视化混淆矩阵实战

python评分卡之LR及混淆矩阵、ROC

python绘制混淆矩阵

Python hmmlearn中的混淆矩阵是怎么表示的