Java中梯度下降逻辑回归的实现

Posted

技术标签:

【中文标题】Java中梯度下降逻辑回归的实现【英文标题】:Implementation of Logistic regression with Gradient Descent in Java 【发布时间】:2015-05-03 15:54:36 【问题描述】:

我已经在 J​​ava 中实现了带有梯度下降的逻辑回归。好像效果不太好(没有正确分类记录;y=1的概率很大。)不知道我的实现是否正确。代码我已经翻了好几遍了,还是不行找到任何错误。我一直在关注 Andrew Ng 在 Course Era 上的机器学习教程。我的 Java 实现有 3 个类。即:

    DataSet.java : 读取数据集 Instance.java:有两个成员:1. double[] x 和 2. double label Lo​​gistic.java :这是使用梯度下降实现逻辑回归的主要类。

这是我的成本函数:

J(Θ) = (- 1/m ) [Σmi=1 y(i) log( hΘ( x(i) ) ) + (1 - y(i) ) log(1 - hΘ ( x(i)) )]

对于上面的成本函数,这是我的梯度下降算法:

重复(

Θj := Θj - α Σmi=1 ( hΘ ( x(i)) - y(i) ) x(i)j

(同时更新所有 Θj

)

import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class Logistic 

    /** the learning rate */
    private double alpha;

    /** the weight to learn */
    private double[] theta;

    /** the number of iterations */
    private int ITERATIONS = 3000;

    public Logistic(int n) 
        this.alpha = 0.0001;
        theta = new double[n];
    

    private double sigmoid(double z) 
        return (1 / (1 + Math.exp(-z)));
    

    public void train(List<Instance> instances) 

    double[] temp = new double[3];

    //Gradient Descent algorithm for minimizing theta
    for(int i=1;i<=ITERATIONS;i++)
    
       for(int j=0;j<3;j++)
             
        temp[j]=theta[j] - (alpha * sum(j,instances));
       

       //simulataneous updates of theta  
       for(int j=0;j<3;j++)
       
         theta[j] = temp[j];
       
        System.out.println(Arrays.toString(theta));
    

    

    private double sum(int j,List<Instance> instances)
    
        double[] x;
        double prediction,sum=0,y;


       for(int i=0;i<instances.size();i++)
       
          x = instances.get(i).getX();
          y = instances.get(i).getLabel();
          prediction = classify(x);
          sum+=((prediction - y) * x[j]);
       
         return (sum/instances.size());

    

    private double classify(double[] x) 
        double logit = .0;
        for (int i=0; i<theta.length;i++)  
            logit += (theta[i] * x[i]);
        
        return sigmoid(logit);
    


    public static void main(String... args) throws FileNotFoundException 

      //DataSet is a class with a static method readDataSet which reads the dataset
      // Instance is a class with two members: double[] x, double label y
      // x contains the features and y is the label.

        List<Instance> instances = DataSet.readDataSet("data.txt");
      // 3 : number of theta parameters corresponding to the features x 
      // x0 is always 1   
        Logistic logistic = new Logistic(3);
        logistic.train(instances);

        //Test data
        double[]x = new double[3];
        x[0]=1;
        x[1]=45;
        x[2] = 85;

        System.out.println("Prob: "+logistic.classify(x));


    

谁能告诉我我做错了什么? 提前致谢! :)

【问题讨论】:

我认为您需要首先确定您遇到的是 Java 问题还是机器学习问题。您的 Java 程序是否实现了预期的功能,无论它是否是正确的功能?您应该能够从单元测试中看出这一点。 你实现了梯度上升,而不是下降。您还需要将总和除以您处理的实例数 - 这就是您的权重爆炸的原因。 @Thomas 抱歉。我正在尝试不同的东西,但我忘了把它改回减号。我做了编辑。即使它是负数,它也没有按预期工作。 但是你确实读到了我之前的评论,对吧?:P 也许你想看看一些更惯用的python代码:***.com/questions/17784587/… 【参考方案1】:

在研究逻辑回归时,我花时间详细检查了您的代码。

TLDR

事实上,看起来算法是正确的。

我认为,你有这么多假阴性或假阳性的原因是因为你选择的超参数。

模型训练不足,因此假设欠拟合。

详情

我不得不创建 DataSetInstance 类,因为您没有发布它们,并基于 Cryotherapy 数据集设置训练数据集和测试数据集。 见http://archive.ics.uci.edu/ml/datasets/Cryotherapy+Dataset+。

然后,使用您相同的确切代码(用于逻辑回归部分)并通过选择 0.001 的 alpha 率和 100000 的迭代次数,我在测试中得到了 80.64516129032258% 的准确率数据集,还不错。

我试图通过手动调整这些超参数来获得更好的准确率,但无法获得更好的结果。

我想,在这一点上,一个增强将是实现正则化。

梯度下降公式

在吴恩达关于成本函数和梯度下降的视频中,省略了1/m 项是正确的。 一种可能的解释是1/m 术语包含在alpha 术语中。 或者,也许这只是一个疏忽。 在 6 分 53 秒见 https://www.youtube.com/watch?v=TTdcc21Ko9A&index=36&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&t=6m53s。

但是,如果您观看 Andrew Ng 的有关正则化和逻辑回归的视频,您会注意到术语 1/m 清楚地出现在公式中。 2 分 19 秒见 https://www.youtube.com/watch?v=IXPgm1e0IOo&index=42&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&t=2m19s。

【讨论】:

以上是关于Java中梯度下降逻辑回归的实现的主要内容,如果未能解决你的问题,请参考以下文章

Python 梯度下降实现逻辑回归

机器学习P6 逻辑回归的 损失函数 以及 梯度下降

机器学习100天(十七):017 逻辑回归梯度下降

机器学习100天(十七):017 逻辑回归梯度下降

python逻辑回归(logistic regression LR) 底层代码实现 BGD梯度下降算法 softmax多分类

关于对率回归的求解,梯度下降和解析解相比有啥特点和优势,为啥?