Finding multiclass thresholds for a simple threshold classifier

Asked: 2016-06-16 22:59:59

Question:

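I have a function that returns a numeric value for an instance, and I later use that value to classify the instance into one of three classes. The classes are reasonably separable, see the figure below (the three colors represent the three classes).

So I want two thresholds, k1 and k2, such that everything to the left of k1 is classified as red, everything to the right of k2 is classified as blue, and everything in between is classified as green.

I started with a modified version of Kadane's algorithm, based on this solution. I first sort all (color, value) tuples by value, then build an array in which every green instance has the value 1 and every non-green instance has the value -1. That gives me an array that looks like this: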

 [-1, -1, -1, -1, 1, -1, -1, ..., 1, 1, 1, -1, 1, ..., -1, -1, -1, -1]

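That is, mostly -1 (red) at the start, mostly green in the middle, and mostly blue at the end. If I now run Kadane's algorithm, will I get the optimal split?

Here is the code I tested: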

import java.util.*;

public class Kadanes {
    private static Color[] correctClasses = new Color[]{Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.GREEN, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN,
Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.RED, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.RED, Color.RED, Color.BLUE, Color.RED, Color.GREEN, Color.RED, Color.RED, Color.RED, Color.RED, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.RED, Color.RED, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.RED, Color.RED, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.RED, Color.RED, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN,
Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.GREEN};
    private static double[] predictedValues = new double[]{0.0, 0.34, 2.0, 2.67, 7.53, -0.04, 2.0, -3.55, 3.78, 0.33, 3.0, -0.21, 1.41, -0.37, 0.84, 3.94, 8.34, 0.0, -1.39, 3.0, -1.63, 0.0, 3.0, 1.26, 0.0, 0.0, 0.0, 0.0, 0.61, 0.0, 3.34, 0.57, -1.05, 0.63, 0.0, 0.71, 0.0, 2.34, -0.41, -1.77, 3.0, 0.62, 0.93, 1.55, 2.0, 8.0, -1.55, 5.75, 0.0, 0.0, -0.25, 0.0, 1.0, 10.51, 0.0, 0.47, 0.78, -1.08, -1.51, 1.0, 1.0, 0.0, 4.33, -0.6, 0.37, 6.0, 1.16, -4.07, 2.0, 0.91, -0.05, 1.78, 0.0, 0.0, 0.0, 0.0, 0.0, 1.64, 1.55, 4.44, 2.78, 1.47, 3.75, 0.0, 7.59, 0.0, 0.94, 2.46, -0.23, -0.2, 0.0, 0.39, -2.31, 3.0, -1.15, 2.0, -0.76, -1.33, 0.0, 0.61, 0.77, -1.77, -1.08, 0.0, -3.2, 3.46, 1.0, 0.0, 0.0, 3.33, 0.0, 0.0, 2.81, 0.0, 0.0, 0.0, 3.0, 0.0, -0.88, 1.65, -1.09, -0.35, 0.0, 0.0, 5.0, 0.0, 2.88, -0.72, 0.87, 7.0, 7.48, -1.98, 1.0, 1.11, 4.0, 1.53, 0.0, 8.07, 1.54, 4.23, 0.0, -0.73, 6.61, 0.07, 0.0, -4.32, -1.77, 2.05, -1.08, 4.3, 1.61, 2.96, 3.0, 0.0, 3.66, 0.0, 0.0, 0.05, -0.77, -1.0, 0.0, 5.43, 2.12, -1.55, 2.3, 0.0, 3.6, 0.0, 0.0, -10.21, 2.0, 0.55, -0.63, 0.0, 1.0, 0.0, 0.0, 1.28, 3.0, 0.0, 0.44, 1.27, 2.12, 2.17, 1.76, -1.9, 5.42, 1.0, 3.76, -3.55, -0.82, 0.0, 0.11, -1.7, -0.33, 0.0, 0.0, -2.01, 0.0, 3.52, 2.0, 6.0, 0.92, 7.22, 0.0, 0.0, 0.0, 0.0, 0.36, -1.77, 0.0, -3.32, -0.91, 2.69, -0.86, -0.27, 3.28, -1.02, 0.41, -0.6, 2.61, 0.0, 0.36, 0.0, 0.91, 0.0, -2.82, 0.0, -1.77, 0.0, -0.33, 3.94, -2.55, 8.0, 3.29, 2.7, -4.4, 9.0, 0.0, 2.81, -0.23, -2.51, 2.0, -0.19, 0.0, 0.0, 0.0, 0.8, 8.33, 0.0, 0.59, 0.0, 0.41, 0.0, 0.8, 1.7, 3.27, 0.0, 0.34, -1.83, 0.0, -1.0, 0.29, 3.71, -0.44, -0.59, 1.25, 2.3, -1.56, 0.0, 6.21, -0.68, 0.0, 0.0, -0.3, 0.0, 1.0, 0.86, 0.0, 0.0, 0.0, 0.0, 0.41, 1.91, -0.17, -0.77, 1.0, 3.0, 2.0, 3.0, -0.71, 0.0, 0.62, 0.0, 2.54, 1.14, 0.0, 0.0, 3.27, 0.0, 0.96, -0.33, 0.0, 0.0, 1.91, -0.2, 0.0, 0.0, 0.6, 0.0, -0.82, 1.0, -0.54, 6.52, -2.48, 2.0, 0.0, 0.0, 1.61, 0.0, 0.0, 0.0, -0.17, 0.0, 1.0, -5.36, 2.73, 0.0, 7.97, 3.67, 0.0, -0.88, 0.93, 0.0, 3.0, -1.03,
-0.64, 2.78, 0.0, 1.0, 3.0, 0.0, 0.46, 0.0, -0.63, 0.0, 4.0, 4.0, 1.61, 0.0, 0.0, 1.07, 0.0, 1.0, 18.39, -1.82, 0.0, 0.86, -0.42, -1.77, -0.61, 0.0, 0.68, -3.13, 0.53, 0.0, 3.0, 0.0, 2.47, 0.0, -1.74, 5.31, 0.0, 0.3, 0.0, 0.0, 4.0, 1.0, 0.64, 1.0, 0.0, -1.77, 3.31, -1.77, -0.43, -3.55, 0.94, 8.59, 0.0, 1.81, 3.69, -1.77, -0.32, 0.0, 3.0, 1.93, -1.47, 1.0, 3.21, 0.0, 0.0, 0.0, 0.33, 0.0, 0.0, -0.39, 0.0, 1.0, 0.0, 1.98, 0.0, 0.0, 7.45, 0.72, 0.34, 0.0, 0.35, 0.0, -2.74, 0.28, 4.0, 3.0, -0.91, -4.43, 0.0, 2.28, 3.0, -2.5, -2.66, 2.0, -0.66, 3.0, 11.06, 1.43, 3.0, 0.0, -0.79, 6.3, 0.94, 3.92, -4.43, 5.14, -2.35, 8.83, 1.04, 2.6, 5.0, 3.72};

    private static List<Tuple> previousResults = new ArrayList<>();
    static {
        // Pair each known class with the value the scoring function produced for it.
        for (int i = 0; i < correctClasses.length; i++) {
            previousResults.add(new Tuple(correctClasses[i], predictedValues[i]));
        }
    }

    public static void main(String[] args) {
        double[] exampleThresholds = new double[]{-1.65, 1.65};
        double[] thresholds = getThreshold();
        System.out.println(Arrays.toString(thresholds));

        System.out.println("Example threshold accuracy: " + getAccuracy(exampleThresholds));
        System.out.println("Optimal threshold accuracy: " + getAccuracy(thresholds));
    }

    private static double[] getThreshold() {
        // compareTo orders tuples by descending value, so reversing it sorts ascending.
        Collections.sort(previousResults, Collections.reverseOrder());

        int max_so_far = 0;
        int max_ending_here = 0;
        int max_start_index = 0;
        int startIndex = 0;
        int max_end_index = -1;

        // Kadane's algorithm over the implicit +1 (green) / -1 (non-green) array.
        for (int i = 0; i < previousResults.size(); i++) {
            int currentElementScore = (previousResults.get(i).correct == Color.GREEN ? 1 : -1);
            if (max_ending_here + currentElementScore < 0) {
                startIndex = i + 1;
                max_ending_here = 0;
            } else {
                max_ending_here += currentElementScore;
            }

            if (max_ending_here > max_so_far) {
                max_so_far = max_ending_here;
                max_start_index = startIndex;
                max_end_index = i;
            }
        }

        // Place each threshold halfway between the values on either side of the range boundary.
        double lowThreshold = getAvgValue(max_start_index - 1, max_start_index);
        double highThreshold = getAvgValue(max_end_index, max_end_index + 1);

        return new double[]{lowThreshold, highThreshold};
    }

    private static double getAccuracy(double[] thresholds) {
        int numCorrectlyClassified = 0;
        for (int i = 0; i < correctClasses.length; i++) {
            Color predictedClassification = classify(predictedValues[i], thresholds[0], thresholds[1]);
            if (predictedClassification == correctClasses[i]) {
                numCorrectlyClassified++;
            }
        }

        return (double) numCorrectlyClassified / correctClasses.length;
    }

    private static Color classify(double value, double lowThresh, double highThresh) {
        if (value < lowThresh) return Color.RED;
        if (value > highThresh) return Color.BLUE;
        return Color.GREEN;
    }

    private static double getAvgValue(int index1, int index2) {
        // A boundary outside the list means the threshold is unbounded on that side.
        if (index1 < 0) {
            return Double.NEGATIVE_INFINITY;
        } else if (index2 >= previousResults.size()) {
            return Double.POSITIVE_INFINITY;
        }

        return (previousResults.get(index1).predicted + previousResults.get(index2).predicted) / 2;
    }

    static class Tuple implements Comparable<Tuple> {
        private Color correct;
        private double predicted;

        Tuple(Color correct, double predicted) {
            this.correct = correct;
            this.predicted = predicted;
        }

        public String toString() {
            return "[" + correct.name() + ", " + predicted + "]";
        }

        @Override
        public int compareTo(Tuple o) {
            // Descending by predicted value; ties are broken by class.
            double diff = o.predicted - predicted;
            return diff != 0 ? (int) Math.signum(diff) : correct.compareTo(o.correct);
        }
    }

    enum Color {
        BLUE, GREEN, RED
    }
}

The output I get is:

[0.0, 0.0]
Example threshold accuracy: 0.5602678571428571
Optimal threshold accuracy: 0.49107142857142855

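So the "optimal" thresholds it finds both sit at 0.0, while a quick example threshold that I simply typed in performs better. Is this an implementation error, or can Kadane's algorithm not be used to solve this simple problem? If it cannot, which algorithm can I use?

Comments:

I think you need to define what "optimal" means.

@GordonLinoff Maximum accuracy.

Answer 1:

You cannot use Kadane's algorithm for this problem, because it optimizes the number of greens. Suppose you have something like: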

1, 1, -1, 1, -1, 1,.., -1, 1, 1, 1

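The algorithm maximizes the sum, so it takes everything from the first two 1s through the last three 1s, because the middle part sums to only -1.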
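To make the mismatch concrete, here is a minimal sketch that runs Kadane's algorithm on the +1/-1 encoding and compares the resulting split with another split by counting correct classifications. The label sequence, the class name `KadaneCounterexample`, and the helper names are hypothetical and chosen only for illustration; they do not come from the question's data.

```java
public class KadaneCounterexample {
    static final int RED = 0, GREEN = 1, BLUE = 2;

    /** Kadane on the +1 (green) / -1 (non-green) encoding; returns {start, endExclusive}. */
    public static int[] kadaneRange(int[] labels) {
        int best = 0, bestStart = 0, bestEnd = 0;
        int current = 0, start = 0;
        for (int k = 0; k < labels.length; k++) {
            current += (labels[k] == GREEN) ? 1 : -1;
            if (current < 0) {
                // A negative running sum can never help; restart after this element.
                current = 0;
                start = k + 1;
            } else if (current > best) {
                best = current;
                bestStart = start;
                bestEnd = k + 1;
            }
        }
        return new int[]{bestStart, bestEnd};
    }

    /** Number of correct predictions when [0,i) -> RED, [i,j) -> GREEN, [j,n) -> BLUE. */
    public static int correct(int[] labels, int i, int j) {
        int count = 0;
        for (int k = 0; k < labels.length; k++) {
            int predicted = k < i ? RED : (k < j ? GREEN : BLUE);
            if (predicted == labels[k]) count++;
        }
        return count;
    }

    public static void main(String[] args) {
        // Hypothetical labels, already sorted by value.
        int[] labels = {GREEN, GREEN, RED, RED, GREEN, BLUE, BLUE};
        int[] r = kadaneRange(labels);
        System.out.println("Kadane range [" + r[0] + ", " + r[1] + ") correct: "
                + correct(labels, r[0], r[1]));
        System.out.println("Split [4, 5) correct: " + correct(labels, 4, 5));
    }
}
```

On this sequence Kadane's range scores the green run well but ignores whether the items outside the range are red on the left and blue on the right, which is exactly what accuracy also depends on.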

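So I would use an N log N algorithm to find the thresholds. First sort the array. Precompute arrays of partial counts that give the number of greens, reds and blues between position 0 and x (for all x). Iterate over the sorted array with the first threshold and, for each position, binary-search for the second threshold, using accuracy as the guiding metric. To compute the accuracy you can use the precomputed partial counts. You need to look at how the metric increases.

There may be some corner cases. If you can afford N^2, just try all threshold pairs and use the precomputed arrays to speed up the evaluation.

Comments:

The reason I thought Kadane's algorithm would work is that the ranges are somewhat separable, with almost no overlap between red and blue, so wouldn't optimizing the number of greens also optimize the others? I don't quite understand your second paragraph: Kadane's selects a range, so even if the middle part sums to 0, the start and end parts are > 0, so it would include the whole range?

@Limon Yes, I mean that because the middle part is 0, the solution will include the whole range, which is probably not what you want.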
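The N^2 variant with precomputed partial counts could be sketched as follows. This is an illustration under assumptions, not the answerer's code: the class name `ThresholdSearch`, the integer class encoding, and the rule that a cut may only fall between two distinct values (one way to handle the tie corner case) are all choices made here.

```java
import java.util.Arrays;

public class ThresholdSearch {
    static final int RED = 0, GREEN = 1, BLUE = 2;

    /**
     * Returns {i, j, correct} over the value-sorted items:
     * positions [0,i) -> RED, [i,j) -> GREEN, [j,n) -> BLUE.
     */
    public static int[] bestSplit(double[] values, int[] labels) {
        int n = values.length;
        // Sort item indices by value.
        Integer[] order = new Integer[n];
        for (int k = 0; k < n; k++) order[k] = k;
        Arrays.sort(order, (a, b) -> Double.compare(values[a], values[b]));

        // prefix[c][k] = number of items of class c among the first k sorted items.
        int[][] prefix = new int[3][n + 1];
        for (int k = 0; k < n; k++) {
            for (int c = 0; c < 3; c++) prefix[c][k + 1] = prefix[c][k];
            prefix[labels[order[k]]][k + 1]++;
        }

        int[] best = {0, 0, -1};
        for (int i = 0; i <= n; i++) {
            if (!isCut(values, order, i)) continue;
            for (int j = i; j <= n; j++) {
                if (!isCut(values, order, j)) continue;
                // Each term is an O(1) lookup thanks to the prefix counts.
                int correct = prefix[RED][i]
                        + (prefix[GREEN][j] - prefix[GREEN][i])
                        + (prefix[BLUE][n] - prefix[BLUE][j]);
                if (correct > best[2]) best = new int[]{i, j, correct};
            }
        }
        return best;
    }

    /** A threshold can only separate two items if their values differ. */
    static boolean isCut(double[] values, Integer[] order, int k) {
        return k == 0 || k == values.length || values[order[k - 1]] < values[order[k]];
    }

    public static void main(String[] args) {
        // Hypothetical data; note the tie at 0.0.
        double[] values = {-2.0, -1.0, 0.0, 0.0, 0.5, 3.0, 4.0};
        int[] labels = {RED, RED, GREEN, BLUE, GREEN, BLUE, BLUE};
        System.out.println(Arrays.toString(bestSplit(values, labels)));
    }
}
```

The returned positions can be turned into numeric thresholds by averaging the two sorted values on either side of each cut, as the question's `getAvgValue` does. Dividing the returned count by the number of items gives the accuracy.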
