Python - 混淆矩阵维度的差异
Posted
技术标签:
【中文标题】Python - 混淆矩阵维度的差异【英文标题】:Python - difference in confusion matrix dimension 【发布时间】:2018-08-22 08:02:45 【问题描述】:我有一个关于混淆矩阵的问题。我使用交叉验证将 148 个实例拆分为两个数组 - 测试和训练。比我这样称呼:
def GenerateResult:
clf = OneVsRestClassifier(GaussianNB())
clf.fit(x_train, y_train)
predictions = clf.predict(x_test)
accuracy = accuracy_score(y_test, predictions)
confusion_mtrx = confusion_matrix(y_test, predictions)
这是 KFold 的循环 -> 我从上调用函数:
for train_idx, test_idx in pf.split(x_array):
x_train, x_test = x_array[train_idx], x_array[test_idx]
y_train, y_test = y_array[train_idx], y_array[test_idx]
acc, confusion= GenerateResult(x_train, x_test, y_train, y_test)
results['First'].append(acc)
confusion_dict['First'].append(confusion)
然后我对结果求和并计算平均值
np_gausian = np.asarray(results['gaussian'])
print("[First] Mean: ".format(np.mean(np_gausian)))
print(confusion_dict['gaussian'])
我有一个问题。在我的 148 个实例中,我有 4 个输出类,当我将该循环用于 KFold 时,我得到了两个不同的混淆矩阵。 第一个混淆矩阵 3x3:
[[36 1 1]
[15 17 1]
[ 0 0 3]]
第二个 4x4 :
[[ 0 2 0 0]
[ 0 41 2 0]
[ 0 12 16 0]
[ 0 0 1 0]]
我认为我有问题,因为在我的 148 实例中我有
1 - 2 类
2 类 - 81 个
3 级 - 61 次
4 - 4 类
所有类 - 148
我该怎么办?我怎样才能总结这个混淆矩阵?如果我更改 KFold 中的拆分数量怎么办?我尝试使用 Pandas,但我不知道该怎么做。请帮忙,我用的是sk-learn
【问题讨论】:
我认为您的问题源于这样一个事实,即第 1 类只有两个观察结果,当您拆分测试时,它们都属于第二类。尝试使用 StratifiedKFold 而不是 KFold,如果这不能帮助手动将其中一个观察结果从第 1 类移动到另一个折叠。 【参考方案1】:正如@KRKirov 在评论中指出的那样,其原因是由于 Kfold 交叉验证将数据拆分为折叠,某些类不存在于该折叠的测试集中。
例如,如果 class1 在y_test
中不存在,并且在predictions
中也没有预测,那么confusion_matrix
代码将自动推断数据中只存在三个类并根据那个。
您可以通过设置labels
param 来强制混淆矩阵使用所有类:-
标签:数组,形状 = [n_classes],可选
List of labels to index the matrix. This may be used to reorder or select a subset of labels. If none is given, those that appear at least once in y_true or y_pred are used in sorted order.
通过这样做:
confusion_mtrx = confusion_matrix(y_test, predictions,
labels = np.unique(y_array))
您需要将 y_array 或唯一标签从 y_array 传递给 GenerateResult() 方法。
【讨论】:
好的,可以了。也来自@KRKirov作品的回答。但我还有另一个问题 - 我应该如何总结我得到的混淆矩阵? 好的,我找到了。我们应该将所有混淆矩阵与该代码相加。谢谢大家 @Zarobiek 如果有帮助,请考虑投票/接受答案。以上是关于Python - 混淆矩阵维度的差异的主要内容,如果未能解决你的问题,请参考以下文章
混淆矩阵是什么?Python多分类的混淆矩阵计算及可视化(包含原始混淆矩阵及归一化的混淆矩阵):基于skelarn框架iris数据集