Java中梯度下降逻辑回归的实现
Posted
技术标签:
【中文标题】Java中梯度下降逻辑回归的实现【英文标题】:Implementation of Logistic regression with Gradient Descent in Java 【发布时间】:2015-05-03 15:54:36 【问题描述】:我已经在 Java 中实现了带有梯度下降的逻辑回归。好像效果不太好(没有正确分类记录;y=1的概率很大。)不知道我的实现是否正确。代码我已经翻了好几遍了,还是不行找到任何错误。我一直在关注 Andrew Ng 在 Course Era 上的机器学习教程。我的 Java 实现有 3 个类。即:
-
DataSet.java : 读取数据集
Instance.java:有两个成员:1. double[] x 和 2. double label
Logistic.java :这是使用梯度下降实现逻辑回归的主要类。
这是我的成本函数:
J(Θ) = (- 1/m ) [Σmi=1 y(i) log( hΘ( x(i) ) ) + (1 - y(i) ) log(1 - hΘ ( x(i)) )]
对于上面的成本函数,这是我的梯度下降算法:重复(
Θj := Θj - α Σmi=1 ( hΘ ( x(i)) - y(i) ) x(i)j
(同时更新所有 Θj ))
import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Logistic
/** the learning rate */
private double alpha;
/** the weight to learn */
private double[] theta;
/** the number of iterations */
private int ITERATIONS = 3000;
public Logistic(int n)
this.alpha = 0.0001;
theta = new double[n];
private double sigmoid(double z)
return (1 / (1 + Math.exp(-z)));
public void train(List<Instance> instances)
double[] temp = new double[3];
//Gradient Descent algorithm for minimizing theta
for(int i=1;i<=ITERATIONS;i++)
for(int j=0;j<3;j++)
temp[j]=theta[j] - (alpha * sum(j,instances));
//simulataneous updates of theta
for(int j=0;j<3;j++)
theta[j] = temp[j];
System.out.println(Arrays.toString(theta));
private double sum(int j,List<Instance> instances)
double[] x;
double prediction,sum=0,y;
for(int i=0;i<instances.size();i++)
x = instances.get(i).getX();
y = instances.get(i).getLabel();
prediction = classify(x);
sum+=((prediction - y) * x[j]);
return (sum/instances.size());
private double classify(double[] x)
double logit = .0;
for (int i=0; i<theta.length;i++)
logit += (theta[i] * x[i]);
return sigmoid(logit);
public static void main(String... args) throws FileNotFoundException
//DataSet is a class with a static method readDataSet which reads the dataset
// Instance is a class with two members: double[] x, double label y
// x contains the features and y is the label.
List<Instance> instances = DataSet.readDataSet("data.txt");
// 3 : number of theta parameters corresponding to the features x
// x0 is always 1
Logistic logistic = new Logistic(3);
logistic.train(instances);
//Test data
double[]x = new double[3];
x[0]=1;
x[1]=45;
x[2] = 85;
System.out.println("Prob: "+logistic.classify(x));
谁能告诉我我做错了什么? 提前致谢! :)
【问题讨论】:
我认为您需要首先确定您遇到的是 Java 问题还是机器学习问题。您的 Java 程序是否实现了预期的功能,无论它是否是正确的功能?您应该能够从单元测试中看出这一点。 你实现了梯度上升,而不是下降。您还需要将总和除以您处理的实例数 - 这就是您的权重爆炸的原因。 @Thomas 抱歉。我正在尝试不同的东西,但我忘了把它改回减号。我做了编辑。即使它是负数,它也没有按预期工作。 但是你确实读到了我之前的评论,对吧?:P 也许你想看看一些更惯用的python代码:***.com/questions/17784587/… 【参考方案1】:在研究逻辑回归时,我花时间详细检查了您的代码。
TLDR
事实上,看起来算法是正确的。
我认为,你有这么多假阴性或假阳性的原因是因为你选择的超参数。
模型训练不足,因此假设欠拟合。
详情
我不得不创建 DataSet
和 Instance
类,因为您没有发布它们,并基于 Cryotherapy 数据集设置训练数据集和测试数据集。
见http://archive.ics.uci.edu/ml/datasets/Cryotherapy+Dataset+。
然后,使用您相同的确切代码(用于逻辑回归部分)并通过选择 0.001
的 alpha 率和 100000
的迭代次数,我在测试中得到了 80.64516129032258
% 的准确率数据集,还不错。
我试图通过手动调整这些超参数来获得更好的准确率,但无法获得更好的结果。
我想,在这一点上,一个增强将是实现正则化。
梯度下降公式
在吴恩达关于成本函数和梯度下降的视频中,省略了1/m
项是正确的。
一种可能的解释是1/m
术语包含在alpha
术语中。
或者,也许这只是一个疏忽。
在 6 分 53 秒见 https://www.youtube.com/watch?v=TTdcc21Ko9A&index=36&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&t=6m53s。
但是,如果您观看 Andrew Ng 的有关正则化和逻辑回归的视频,您会注意到术语 1/m
清楚地出现在公式中。
2 分 19 秒见 https://www.youtube.com/watch?v=IXPgm1e0IOo&index=42&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&t=2m19s。
【讨论】:
以上是关于Java中梯度下降逻辑回归的实现的主要内容,如果未能解决你的问题,请参考以下文章
python逻辑回归(logistic regression LR) 底层代码实现 BGD梯度下降算法 softmax多分类