Mahout - Simple classification issue

Posted: 2012-06-26 17:10:33

Question:

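I am trying to build a simple model that classifies points into 2 partitions of a 2D space.

    I train the model by specifying a few points and the partitions they belong to. I then use the model to predict the group (classification) that a test point is likely to fall into.

Unfortunately, I am not getting the expected answer. Am I missing something in my code, or am I doing something wrong?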

public class SimpleClassifier {

    public static class Point {
        public int x;
        public int y;

        public Point(int x, int y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public boolean equals(Object arg0) {
            Point p = (Point) arg0;
            return (this.x == p.x) && (this.y == p.y);
        }

        @Override
        public String toString() {
            return this.x + " , " + this.y;
        }
    }

    public static void main(String[] args) {

        Map<Point, Integer> points = new HashMap<SimpleClassifier.Point, Integer>();

        points.put(new Point(0, 0), 0);
        points.put(new Point(1, 1), 0);
        points.put(new Point(1, 0), 0);
        points.put(new Point(0, 1), 0);
        points.put(new Point(2, 2), 0);

        points.put(new Point(8, 8), 1);
        points.put(new Point(8, 9), 1);
        points.put(new Point(9, 8), 1);
        points.put(new Point(9, 9), 1);

        OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression();
        learningAlgo = new OnlineLogisticRegression(2, 2, new L1());
        learningAlgo.learningRate(50);

        //learningAlgo.alpha(1).stepOffset(1000);

        System.out.println("training model  \n");
        for (Point point : points.keySet()) {
            Vector v = getVector(point);
            System.out.println(point + " belongs to " + points.get(point));
            learningAlgo.train(points.get(point), v);
        }

        learningAlgo.close();

        //now classify real data
        Vector v = new RandomAccessSparseVector(2);
        v.set(0, 0.5);
        v.set(1, 0.5);

        Vector r = learningAlgo.classifyFull(v);
        System.out.println(r);

        System.out.println("ans = ");
        System.out.println("no of categories = " + learningAlgo.numCategories());
        System.out.println("no of features = " + learningAlgo.numFeatures());
        System.out.println("Probability of cluster 0 = " + r.get(0));
        System.out.println("Probability of cluster 1 = " + r.get(1));
    }

    public static Vector getVector(Point point) {
        Vector v = new DenseVector(2);
        v.set(0, point.x);
        v.set(1, point.y);

        return v;
    }
}

Output:

ans = 
no of categories = 2
no of features = 2
Probability of cluster 0 = 3.9580985042775296E-4
Probability of cluster 1 = 0.9996041901495722

The output shows a probability of more than 99% for cluster 1. Why?

Comments:

@sean-owen Can you help me with this problem?

Please post the expected output.

Answer 1:

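The problem is that you did not include the bias (intercept) term, which is always 1. You need to add the bias term (1) to the feature vector built from your Point class.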
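To illustrate why the missing bias term matters, here is a minimal sketch in plain Java with no Mahout dependency. It assumes an ordinary two-class logistic model, P(class 1 | x) = sigmoid(w . x); the class name BiasTermSketch, the method probClass1 and all weight values are made up for illustration and are not taken from Mahout or from a trained model. Without a constant feature the score at the origin is always 0, so the decision boundary is forced through the origin, and once the weights turn positive (as the class 1 points at 8 and 9 push them), every point with positive coordinates, including (0.5, 0.5), leans toward class 1.

```java
public class BiasTermSketch {

    // Logistic (sigmoid) function: maps any real score to a probability in (0, 1).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // P(class 1 | x) for a plain logistic model: sigmoid of the dot product w . x.
    static double probClass1(double[] w, double[] x) {
        double z = 0.0;
        for (int i = 0; i < w.length; i++) {
            z += w[i] * x[i];
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        // No bias feature: whatever the weights are, the origin scores z = 0,
        // so the model can never say anything but 0.5 there.
        double[] noBiasWeights = {3.0, -7.0}; // arbitrary illustrative values
        System.out.println(probClass1(noBiasWeights, new double[] {0.0, 0.0}));

        // With a constant third feature equal to 1, the third weight acts as an
        // intercept and can move the boundary away from the origin.
        // Hypothetical weights that put the boundary on the line x + y = 10:
        double[] biasWeights = {1.0, 1.0, -10.0};
        System.out.println(probClass1(biasWeights, new double[] {0.5, 0.5, 1.0})); // below 0.5
        System.out.println(probClass1(biasWeights, new double[] {9.0, 9.0, 1.0})); // above 0.5
    }
}
```

The hand-picked weights only show what an intercept makes possible; the actual values have to come from training.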

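This is a very basic mistake, and one that even many people experienced in machine learning make. It might be a good idea to spend some time learning the theory. Andrew Ng's lectures are a good place to start.

To make your code produce the expected output, the following needed to change.

    The bias term was added.
    The learning rate was too high. It was changed to 10.

Now you will get P(0)=0.9999 for class 0.

Here is a complete working example that gives the correct result: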

import java.util.HashMap;
import java.util.Map;

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;


class Point {
    public int x;
    public int y;

    public Point(int x, int y) {
        this.x = x;
        this.y = y;
    }

    @Override
    public boolean equals(Object arg0) {
        Point p = (Point) arg0;
        return (this.x == p.x) && (this.y == p.y);
    }

    @Override
    public String toString() {
        return this.x + " , " + this.y;
    }
}

public class SimpleClassifier {

    public static void main(String[] args) {

        Map<Point, Integer> points = new HashMap<Point, Integer>();

        points.put(new Point(0, 0), 0);
        points.put(new Point(1, 1), 0);
        points.put(new Point(1, 0), 0);
        points.put(new Point(0, 1), 0);
        points.put(new Point(2, 2), 0);

        points.put(new Point(8, 8), 1);
        points.put(new Point(8, 9), 1);
        points.put(new Point(9, 8), 1);
        points.put(new Point(9, 9), 1);

        // 2 categories, 3 features: x, y and the constant bias term
        OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression();
        learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
        learningAlgo.lambda(0.1);
        learningAlgo.learningRate(10);

        System.out.println("training model  \n");

        for (Point point : points.keySet()) {
            Vector v = getVector(point);
            System.out.println(point + " belongs to " + points.get(point));
            learningAlgo.train(points.get(point), v);
        }

        learningAlgo.close();

        // The test point needs the bias term as well
        Vector v = new RandomAccessSparseVector(3);
        v.set(0, 0.5);
        v.set(1, 0.5);
        v.set(2, 1);

        Vector r = learningAlgo.classifyFull(v);
        System.out.println(r);

        System.out.println("ans = ");
        System.out.println("no of categories = " + learningAlgo.numCategories());
        System.out.println("no of features = " + learningAlgo.numFeatures());
        System.out.println("Probability of cluster 0 = " + r.get(0));
        System.out.println("Probability of cluster 1 = " + r.get(1));
    }

    public static Vector getVector(Point point) {
        Vector v = new DenseVector(3);
        v.set(0, point.x);
        v.set(1, point.y);
        v.set(2, 1); // bias (intercept) term, always 1
        return v;
    }
}

Output:

2 , 2 belongs to 0
1 , 0 belongs to 0
9 , 8 belongs to 1
8 , 8 belongs to 1
0 , 1 belongs to 0
0 , 0 belongs to 0
1 , 1 belongs to 0
9 , 9 belongs to 1
8 , 9 belongs to 1
0:2.470723149516907E-6,1:0.9999975292768505
ans = 
no of categories = 2
no of features = 3
Probability of cluster 0 = 2.470723149516907E-6
Probability of cluster 1 = 0.9999975292768505

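Note that I defined the Point class outside the SimpleClassifier class, but this is only to make the code more readable and is not required.

See what happens when you change the learning rate. Read the notes on cross-validation to understand how to choose the learning rate.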

Learning Rate => Probability of cluster 0
0.001 => 0.4991116089
0.01 => 0.492481585
0.1 => 0.469961472
1 => 0.5327745322
10 => 0.9745740393
100 => 0
1000 => 0

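Choosing the learning rate:

    It is more common to run stochastic gradient descent with a fixed learning rate α. However, by slowly letting the learning rate α decrease to zero as the algorithm runs, it is also possible to ensure that the parameters converge to the global minimum rather than merely oscillating around it. In this case, where we use a constant α, you can make an initial choice, run gradient descent while observing the cost function, and adjust the learning rate accordingly. Explained here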
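The idea of shrinking α while watching the cost can be sketched in plain Java. This is only an illustration under stated assumptions: the schedule α_t = α0 / (1 + t / τ) is one common choice and is not claimed to be the formula Mahout uses internally, and the names LearningRateDecaySketch, decayedRate, meanLogLoss and train are hypothetical. The data are the nine training points from the question with the bias feature appended, interleaved by class so that neither cluster is presented in one block.

```java
public class LearningRateDecaySketch {

    // Training points from the question, each with a constant bias feature of 1.
    static final double[][] X = {
        {0, 0, 1}, {8, 8, 1}, {1, 1, 1}, {8, 9, 1}, {1, 0, 1},
        {9, 8, 1}, {0, 1, 1}, {9, 9, 1}, {2, 2, 1}
    };
    static final int[] Y = {0, 1, 0, 1, 0, 1, 0, 1, 0};

    // Assumed schedule: starts at alpha0 and shrinks toward zero as step grows.
    static double decayedRate(double alpha0, int step, double tau) {
        return alpha0 / (1.0 + step / tau);
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static double dot(double[] w, double[] x) {
        double z = 0.0;
        for (int i = 0; i < w.length; i++) {
            z += w[i] * x[i];
        }
        return z;
    }

    // Cost function to observe: mean negative log-likelihood over the training set.
    static double meanLogLoss(double[] w) {
        double total = 0.0;
        for (int i = 0; i < X.length; i++) {
            double p = sigmoid(dot(w, X[i]));
            total += (Y[i] == 1) ? -Math.log(p) : -Math.log(1.0 - p);
        }
        return total / X.length;
    }

    // Plain stochastic gradient descent for logistic regression with a decaying rate.
    static double[] train(double alpha0, double tau, int epochs) {
        double[] w = new double[3];
        int step = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int i = 0; i < X.length; i++) {
                double alpha = decayedRate(alpha0, step++, tau);
                double error = Y[i] - sigmoid(dot(w, X[i]));
                for (int j = 0; j < w.length; j++) {
                    w[j] += alpha * error * X[i][j];
                }
            }
            if (epoch % 50 == 0) {
                System.out.println("epoch " + epoch + " cost = " + meanLogLoss(w));
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[] w = train(0.01, 1000.0, 200);
        System.out.println("final cost = " + meanLogLoss(w));
    }
}
```

If the printed cost jumps around or grows, the starting rate is probably too large; if it barely moves, it is probably too small. Trying a few values of alpha0 and comparing the cost curves is the manual tuning the quoted passage describes.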

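Comments:

Can you share the link to the cross-validation notes you mentioned, the ones that explain how to choose the training rate?

Hi @mucaho, I have edited my answer to add it. For other notes on ML, I would recommend cs229.stanford.edu/materials.html

You say P(0)=0.9999 for class 0, but your console output shows Probability of cluster 0 = 2.470723149516907E-6 and Probability of cluster 1 = 0.9999975292768505. I verified the output, and it is the same on my machine. Am I missing something?

Answer 2: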

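I think there are a few potential problems with your classification example:

    Use the default values for training with OnlineLogisticRegression (learningRate etc.).
    Introduce a constant bias (it is just another predictor variable with the constant value 1).
    Shuffle the training data (do not feed all the training data for cluster 1 first and then the data for cluster 2).
    Increase the amount of training data significantly.

For more details on these potential problems, see the book Mahout in Action.

Result after "fixing" the potential problems: the test point <0.5, 0.5> is classified into cluster 0 with a probability of approx. 0.89, consistently across multiple runs. This sounds like a reasonable output, since the other points near the origin (used to train the model) also belong to cluster 0.

Code: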

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class SimpleClassifier {

    public static class Point {
        public int x;
        public int y;

        public Point(int x, int y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public boolean equals(Object arg0) {
            Point p = (Point) arg0;
            return (this.x == p.x) && (this.y == p.y);
        }

        @Override
        public String toString() {
            return this.x + " , " + this.y;
        }
    }

    public static void main(String[] args) {

        Map<Point, Integer> points = new HashMap<Point, Integer>();

        points.put(new Point(0, 0), 0);
        points.put(new Point(1, 1), 0);
        points.put(new Point(1, 0), 0);
        points.put(new Point(0, 1), 0);
        points.put(new Point(2, 2), 0);

        points.put(new Point(8, 8), 1);
        points.put(new Point(8, 9), 1);
        points.put(new Point(9, 8), 1);
        points.put(new Point(9, 9), 1);

        // 2 categories, 3 features (x, y, bias); default learning parameters
        OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 3, new L1());

        System.out.println("training model  \n");
        // 100 passes over the data, reshuffled on every pass
        for (int i = 0; i < 100; i++) {
            List<Point> randomPoints = new ArrayList<>(points.keySet());
            Collections.shuffle(randomPoints);
            for (Point point : randomPoints) {
                Vector v = getVector(point);
                System.out.println(point + " belongs to " + points.get(point));
                learningAlgo.train(points.get(point), v);
            }
        }
        learningAlgo.close();

        //now classify real data
        Vector v = new RandomAccessSparseVector(3);
        v.set(0, 0.5);
        v.set(1, 0.5);
        v.set(2, 1);

        // classify() omits category 0, so with 2 categories r holds only the score for cluster 1
        Vector r = learningAlgo.classify(v);
        System.out.println(r);

        System.out.println("ans = ");
        System.out.println("no of categories = " + learningAlgo.numCategories());
        System.out.println("no of features = " + learningAlgo.numFeatures());
        System.out.println("Probability of cluster 0 = " + (1.0d - r.get(0)));
        System.out.println("Probability of cluster 1 = " + r.get(0));
    }

    public static Vector getVector(Point point) {
        Vector v = new DenseVector(3);
        v.set(0, point.x);
        v.set(1, point.y);
        v.set(2, 1); // constant bias term

        return v;
    }
}

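Comments:

Minor point: do not edit a question in a way that changes it, i.e. by adding extra explanation. That information can go in your answer (as you did here) or in a comment on the question.

@admdrew OK, I thought the additional explanation was part of the question (e.g. the user mentioned that the console output was wrong but did not say what he expected to see; I just extracted his expectation from the source code so that others would not need to read through the source to see his expected console output).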
