apache spark mllib中的逻辑回归 - mnist
Posted
技术标签:
【中文标题】Apache Spark MLlib 中的逻辑回归 - MNIST【英文标题】:Logistic regression in apache spark mllib - mnist 【发布时间】:2018-04-06 22:33:16 【问题描述】:我正在尝试使用逻辑回归创建一个分类器,根据像素(特征)的值预测正确的数字(标签)。我在 Java 中使用 Apache Spark,并且是先把 MNIST 数据库中的数据转换为 libsvm 格式之后再使用这些数据的。这是我的代码:
package ml;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.optimization.L1Updater;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import scala.Tuple2;
public class MNIST5
static String trainImagesPath = "train-images.idx3-ubyte";
static String trainLabelsPath = "train-labels.idx1-ubyte";
static String testImagesPath = "t10k-images.idx3-ubyte";
static String testLabelsPath = "t10k-labels.idx1-ubyte";
static SparkConf conf = new SparkConf()
.setMaster("local")
.setAppName("Machine learning - MNIST Example");
static SparkContext sc = SparkContext.getOrCreate(conf);
public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException
mnist_spark_logistic_regression();
//saveMnistDataLibsvmFormat();
static void mnist_spark_logistic_regression()
long t;
System.out.println("Loading training data ...");
t = System.currentTimeMillis();
JavaRDD<LabeledPoint> trainData = MLUtils.loadLibSVMFile(sc, "mnist-train-data.txt").toJavaRDD();
System.out.println(System.currentTimeMillis()-t+" ms"); // 6661 ms
System.out.println("Training logistic regression classifier ...");
t = System.currentTimeMillis();
// Run training algorithm to build the model.
LogisticRegressionWithLBFGS lr = new LogisticRegressionWithLBFGS()
.setNumClasses(10);
//lr.optimizer().setUpdater(new L1Updater());
LogisticRegressionModel model = lr.run(trainData.rdd());
System.out.println(System.currentTimeMillis()-t+" ms"); // 1951 ms
// print weights and intercept
System.out.println("numClasses: "+model.numClasses());
System.out.println("numFeatures: "+model.numFeatures());
System.out.println("Weights: "+model.weights());
System.out.println("Wlength: "+model.weights().size());
System.out.println("Intercept: "+model.intercept());
System.out.println("Loading testing data ...");
t = System.currentTimeMillis();
JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, "mnist-test-data.txt").toJavaRDD();
System.out.println(System.currentTimeMillis()-t+" ms"); // 11356 ms
System.out.println("Compute raw scores on the test set ...");
t = System.currentTimeMillis();
// Compute raw scores on the test set.
JavaPairRDD<Object, Object> predictionAndLabels = testData.mapToPair(
(p) ->
return new Tuple2<>(model.predict(p.features()), p.label());
);
System.out.println(System.currentTimeMillis()-t+" ms"); // 47 ms
System.out.println("Iterate ...");
t = System.currentTimeMillis();
JavaRDD<Integer> wyw = testData.map(new Function<LabeledPoint, Integer>()
@Override
public Integer call(LabeledPoint t1) throws Exception
double yb = model.predict(t1.features());
if(yb==t1.label())
System.out.println("label: "+t1.label()+", predicted: "+yb);
return 0;
);
wyw.collect();
System.out.println(System.currentTimeMillis()-t+" ms");
System.out.println("Evaluating ...");
t = System.currentTimeMillis();
// Get evaluation metrics.
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
double accuracy = metrics.accuracy();
System.out.println("Accuracy = " + accuracy); // 0.098
System.out.println(System.currentTimeMillis()-t+" ms"); // 1108 ms
// Save and load model
model.save(sc, "mnist_logreg_model"+"/javaMNISTLogisticRegressionWithLBFGSModel");
LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "mnist_logreg_model"+"/javaMNISTLogisticRegressionWithLBFGSModel");
System.out.println(sameModel);
static ArrayList<LabeledPoint> getData(String imagesPath, String labelsPath)
JavaRDD<LabeledPoint> data;
ArrayList<LabeledPoint> lpts = new ArrayList<>();
FileInputStream inImage = null;
FileInputStream inLabel = null;
try
inImage = new FileInputStream(imagesPath);
inLabel = new FileInputStream(labelsPath);
int magicNumberImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
int numberOfImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
int numberOfRows = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
int numberOfColumns = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
int magicNumberLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());
int numberOfLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());
int numberOfPixels = numberOfRows * numberOfColumns;
double[] imgPixels = new double[numberOfPixels];
for(int i = 0; i < numberOfImages; i++)
//if(i % 100 == 0) System.out.println("Number of images extracted: " + i);
for(int p = 0; p < numberOfPixels; p++)
imgPixels[p] = inImage.read();
int label = inLabel.read();
LabeledPoint lp = LabeledPoint.apply(label, Vectors.dense(imgPixels));
lpts.add(lp);
catch (FileNotFoundException e) e.printStackTrace();
catch (IOException e) e.printStackTrace();
finally
if (inImage != null)
try
inImage.close();
catch (IOException e) e.printStackTrace();
if (inLabel != null)
try
inLabel.close();
catch (IOException e) e.printStackTrace();
return lpts;
static JavaRDD<LabeledPoint> loadData(String imagesPath, String labelsPath)
ArrayList<LabeledPoint> lpts = getData(imagesPath, labelsPath);
JavaSparkContext jsc = new JavaSparkContext(sc);
JavaRDD<LabeledPoint> data = jsc.parallelize(lpts);
return data;
static void saveMnistDataLibsvmFormat() throws FileNotFoundException, UnsupportedEncodingException
ArrayList<LabeledPoint> data = getData(testImagesPath, testLabelsPath);
PrintWriter writer = new PrintWriter("mnist-test-data.txt", "UTF-8");
for(LabeledPoint lp : data)
StringBuilder s = new StringBuilder();
s.append(lp.label()).append(" ");
int i;
double[] arr = lp.features().toArray();
for(i=0;i<arr.length-1;i++)
if(arr[i]!=0)
s.append(i+1).append(":").append(arr[i]).append(" ");
if(arr[i]!=0)
s.append(i+1).append(":").append(arr[i]);
writer.println(s.toString());
// writer.println("The first line");
// writer.println("The second line");
writer.close();
ArrayList<LabeledPoint> data2 = getData(trainImagesPath, trainLabelsPath);
PrintWriter writer2 = new PrintWriter("mnist-train-data.txt", "UTF-8");
for(LabeledPoint lp : data2)
StringBuilder s = new StringBuilder();
s.append(lp.label()).append(" ");
int i;
double[] arr = lp.features().toArray();
for(i=0;i<arr.length-1;i++)
if(arr[i]!=0)
s.append(i+1).append(":").append(arr[i]).append(" ");
if(arr[i]!=0)
s.append(i+1).append(":").append(arr[i]);
writer2.println(s.toString());
// writer.println("The first line");
// writer.println("The second line");
writer2.close();
训练得到的权重值全部为零,我不明白为什么。请帮忙,谢谢。
【问题讨论】:
【参考方案1】:下面这段代码是什么意思?
`if(yb==t1.label())
System.out.println("label: "+t1.label()+", predicted: "+yb);
return 0;`
它总是返回 0。
【讨论】:
我写这段代码只是为了查看预测正确的结果;所有的预测都等于零,因为所有的权重都是零。以上是关于apache spark mllib中的逻辑回归 - mnist的主要内容,如果未能解决你的问题,请参考以下文章
Apache Spark MLlib:如何从 PMML 导入模型