多分类问题的处理策略和评估手段

Posted Laurence Geng

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了多分类问题的处理策略和评估手段相关的知识,希望对你有一定的参考价值。

1. 多分类问题的处理策略

多分类问题基本都是建立在二分类问题基础之上的,简单说就是:将多分类问题拆解成多个二分类问题去解决,具体来说,通常有两种策略:

  • One-Versus-The-Rest (OvR)

    One-Versus-The-Rest (OvR) 也叫 One-Versus-All(OvA):即每一个类别和所有其他类别做一次二分类,全部类别都做完后,就等于实现了多分类。一个有N种分类的问题使用此策略需要进行N次二分类处理

  • One-Versus-One(OvO)

    即每一个类别都和另一个类比做一次1V1的二分类,全部类别都分别和其他类别做完后,就等于实现了多分类。一个有N种分类的问题使用此策略需要进行N*(N-1)/2次二分类处理

上述两种策略还有显著差异的:使用OvR,单次训练的数据集会比较大(以MNIST数据集为例,每个二分类模型使用的是全体训练数据),但是训练的模型数量会比较少(共10个);使用OvO,单次训练的数据集会比较小(以MNIST数据集为例,每个二分类模型使用的是全体训练数据的20%),但是训练的模型数量会比较多(共45个)。

大多数的二分类算法比较适合OvR(1V其他),有一小部分二分类算法,比如SVM分类器在处理大型数据集时性能是很差的,它们更擅长在较小的数据集上训练更多的分类器,而不是在较大的数据集上训练较少的分类器。Sklearn在这方面做的非常智能,你无需手动设置,当Sklearn检测到你在使用二分类算法处理多分类问题时(根据标注数据的种类即可判断出当前是一个二分类问题还是多分类问题),它会根据你使用的算法自动选择OvO或OvR。

2. 多分类结果的评估手段

2.1 N维混淆矩阵

二分类问题的预测结果是一个2×2的混淆矩阵,一个多分类模型的预测结果将是一个N×N的混淆矩阵。通常,为了更好地查看结果,通过绘制热力图可以很好地辅助观察数据,为此,Sklearn还专门封装了一个类,用于展示多维混淆矩阵,以下是《Hands-On ML》一书中针对MNIST数据集(数字图片10分类问题)绘制的热力图:

说一下怎么看这个图,对于一个二分类的混淆矩阵它的行例代表的是“真”和“假”,而一个10X10的混淆矩阵应该这么看:以第一行为例,这一行是实际为数字0的所有图片,10个单元格分别把实际为数字0的图片预测成0, 1, 2,…9(10个类别)的样本数量,其他行以此类推。所以主对角线上的数字其实是TP,而FP、FN、TN其实是被拆分成了多个分量,怎么理解呢?我们以数字0的预测结果为例,它的TP、FP、FN、TN分别如下图所示:

2.2 归一化处理

虽然上图中主对角线上的TP数字比较醒目,但却很难看出预测错误和正确的比例,所以,一般会做一个“归一化”处理,处理逻辑是:用当前单元格的数据除以行或列的数据总和,如果是按行归一化,所得百分比是当前实际分类被预测成各种分类的比例,如果是按列归一化,所得百分比是当前预测分类实际为各个分类的比例。下图是按行归一化后的结果,比上图清晰了很多:

2.3 去除正确结果,聚焦错误

但在实际开发中,分析人员往往更关注的是“那些预测错误”的数据,以便进行算法调优,所以,人们通常希望把预测正确的数据(即TP)从图表中“抹去”,这样,还能变向“放大”错误数据的比值(因为分母变小了,只有错误数据和错误数据之间的比较了),对此,sklearn库用于绘制混淆矩阵的方法中专门提供了一个sample_weight参数,这是一个和要输出的混淆矩阵布局一一对应的矩阵,矩阵中的每一个元素是一个权重值,混淆矩阵中的值会和对应的权重值相乘,得到变换后的新矩阵,然后用于绘制图表。所以“抹去”预测正确的数据做法很简单:通过bool运算对预测结果(10X10的混淆矩阵)进行转换,转换为bool类型的矩阵:预测对的为False,预测错的为True。把这个结果作为权重传给from_predictions()的sample_weight参数,这样输出的矩阵会与权重矩阵对应位置上的权值相乘。由于True=1, False=0,所以相乘后,矩阵中预测正确的值都变成了0,如果同时再配合上归一化(注意,此时归一化行或列的总和都不包含正确的样本数量了,所以是错误和错误之间的比值,比值会放大),就可以得到实用性最高的一张图表了,以下是去除预测正确的数据后,分别按行和按列生成的混淆矩阵:

2.4 观察和解读数据

上面输出的两张图表其实给了我们很多信息,我们应该认真观察并分析原因,为后续的算法优化提供线索,这些重要的信息包括:

  1. 左侧图表显示:数字8在各个类别(0-9)中被错误识别的比例都很高,这说明由于数字8本身的“形状”,它就是很容易被识别错误
  2. 右侧图表显示:算法在对每一个数字的预测中,基本都有一到两个错误率比较高的数字,这说明这个数字本身和那一到两个数字的“形状”非常类似从而导致判定错误率比较高。例如:9和7 ,5和3,它的错误占和人们的直观感受是相符合的。

以上是关于多分类问题的处理策略和评估手段的主要内容,如果未能解决你的问题,请参考以下文章

Sklearn:评估 GridSearchCV 中 OneVsRestClassifier 的每个分类器的性能

评估指标与评分(下):多分类指标及其他

为多标签分类评估 DNNClassifier

sklearn多分类模型评测(LR, linearSVC, lightgbm)

sklearn-分类决策树

使用 sklearn 的 roc_auc_score 进行 OneVsOne 多分类?