apache spark mllib中的逻辑回归 - mnist

Posted

技术标签:

【中文标题】apache spark mllib中的逻辑回归 - mnist【英文标题】:Logistic regression in apache spark mllib - mnist 【发布时间】:2018-04-06 22:33:16 【问题描述】:

我正在尝试使用逻辑回归创建一个分类器，根据像素（特征）的值预测正确的数字（标签）。我在 Java 中使用 Apache Spark，先把 MNIST 数据库中的数据转换为 libsvm 格式，然后再使用它。这是我的代码：

package ml;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.optimization.L1Updater;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import scala.Tuple2;


public class MNIST5 

    static String trainImagesPath = "train-images.idx3-ubyte";
    static String trainLabelsPath = "train-labels.idx1-ubyte";
    static String testImagesPath = "t10k-images.idx3-ubyte";
    static String testLabelsPath = "t10k-labels.idx1-ubyte";


    static SparkConf conf = new SparkConf()
            .setMaster("local")
            .setAppName("Machine learning - MNIST Example");

    static SparkContext sc = SparkContext.getOrCreate(conf);

    public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException 

        mnist_spark_logistic_regression();
        //saveMnistDataLibsvmFormat();

    

    static void mnist_spark_logistic_regression()

        long t;

        System.out.println("Loading training data ...");
        t = System.currentTimeMillis();
        JavaRDD<LabeledPoint> trainData = MLUtils.loadLibSVMFile(sc, "mnist-train-data.txt").toJavaRDD();
        System.out.println(System.currentTimeMillis()-t+" ms"); // 6661 ms


        System.out.println("Training logistic regression classifier ...");
        t = System.currentTimeMillis();
        // Run training algorithm to build the model.         
        LogisticRegressionWithLBFGS lr = new LogisticRegressionWithLBFGS()           
            .setNumClasses(10);
        //lr.optimizer().setUpdater(new L1Updater());        
        LogisticRegressionModel model = lr.run(trainData.rdd());

        System.out.println(System.currentTimeMillis()-t+" ms"); // 1951 ms
        // print weights and intercept
        System.out.println("numClasses: "+model.numClasses());
        System.out.println("numFeatures: "+model.numFeatures());
        System.out.println("Weights: "+model.weights());
        System.out.println("Wlength: "+model.weights().size());
        System.out.println("Intercept: "+model.intercept());


        System.out.println("Loading testing data ...");
        t = System.currentTimeMillis();
        JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, "mnist-test-data.txt").toJavaRDD();
        System.out.println(System.currentTimeMillis()-t+" ms"); // 11356 ms


        System.out.println("Compute raw scores on the test set ...");
        t = System.currentTimeMillis();
        // Compute raw scores on the test set.
        JavaPairRDD<Object, Object> predictionAndLabels = testData.mapToPair(
            (p) -> 
                return new Tuple2<>(model.predict(p.features()), p.label());
            
        );
        System.out.println(System.currentTimeMillis()-t+" ms"); // 47 ms


        System.out.println("Iterate ...");
        t = System.currentTimeMillis();
        JavaRDD<Integer> wyw = testData.map(new Function<LabeledPoint, Integer>() 
            @Override
            public Integer call(LabeledPoint t1) throws Exception 
                double yb = model.predict(t1.features());

                if(yb==t1.label())
                    System.out.println("label: "+t1.label()+", predicted: "+yb);
                return 0;
            
        );
        wyw.collect();
        System.out.println(System.currentTimeMillis()-t+" ms");


        System.out.println("Evaluating ...");
        t = System.currentTimeMillis();
        // Get evaluation metrics.
        MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
        double accuracy = metrics.accuracy();
        System.out.println("Accuracy = " + accuracy); // 0.098
        System.out.println(System.currentTimeMillis()-t+" ms"); // 1108 ms


        // Save and load model
        model.save(sc, "mnist_logreg_model"+"/javaMNISTLogisticRegressionWithLBFGSModel");
        LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "mnist_logreg_model"+"/javaMNISTLogisticRegressionWithLBFGSModel");
        System.out.println(sameModel);

    


    static ArrayList<LabeledPoint> getData(String imagesPath, String labelsPath)        

        JavaRDD<LabeledPoint> data;
        ArrayList<LabeledPoint> lpts = new ArrayList<>();

        FileInputStream inImage = null;
        FileInputStream inLabel = null;

        try 
            inImage = new FileInputStream(imagesPath);
            inLabel = new FileInputStream(labelsPath);

            int magicNumberImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfRows  = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfColumns = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());

            int magicNumberLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());
            int numberOfLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());

            int numberOfPixels = numberOfRows * numberOfColumns;
            double[] imgPixels = new double[numberOfPixels];

            for(int i = 0; i < numberOfImages; i++) 

                //if(i % 100 == 0) System.out.println("Number of images extracted: " + i);

                for(int p = 0; p < numberOfPixels; p++) 
                    imgPixels[p] = inImage.read();
                

                int label = inLabel.read();

                LabeledPoint lp = LabeledPoint.apply(label, Vectors.dense(imgPixels));
                lpts.add(lp);

            

         
        catch (FileNotFoundException e)  e.printStackTrace();  
        catch (IOException e)  e.printStackTrace();  
        finally 
            if (inImage != null) 
                try 
                    inImage.close();
                 catch (IOException e)  e.printStackTrace(); 
            
            if (inLabel != null) 
                try 
                    inLabel.close();
                 catch (IOException e)  e.printStackTrace(); 
            
        

        return lpts;
    

    static JavaRDD<LabeledPoint> loadData(String imagesPath, String labelsPath)        

        ArrayList<LabeledPoint> lpts = getData(imagesPath, labelsPath);

        JavaSparkContext jsc = new JavaSparkContext(sc);
        JavaRDD<LabeledPoint> data = jsc.parallelize(lpts);

        return data;
    

    static void saveMnistDataLibsvmFormat() throws FileNotFoundException, UnsupportedEncodingException

        ArrayList<LabeledPoint> data = getData(testImagesPath, testLabelsPath);        
        PrintWriter writer = new PrintWriter("mnist-test-data.txt", "UTF-8");
        for(LabeledPoint lp : data)
            StringBuilder s = new StringBuilder();
            s.append(lp.label()).append(" ");
            int i;
            double[] arr = lp.features().toArray();
            for(i=0;i<arr.length-1;i++)
                if(arr[i]!=0)
                    s.append(i+1).append(":").append(arr[i]).append(" ");
            if(arr[i]!=0)
                s.append(i+1).append(":").append(arr[i]);
            writer.println(s.toString());
        
        // writer.println("The first line");
        // writer.println("The second line");
        writer.close();

        ArrayList<LabeledPoint> data2 = getData(trainImagesPath, trainLabelsPath);
        PrintWriter writer2 = new PrintWriter("mnist-train-data.txt", "UTF-8");
        for(LabeledPoint lp : data2)
            StringBuilder s = new StringBuilder();
            s.append(lp.label()).append(" ");
            int i;
            double[] arr = lp.features().toArray();
            for(i=0;i<arr.length-1;i++)
                if(arr[i]!=0)
                    s.append(i+1).append(":").append(arr[i]).append(" ");
            if(arr[i]!=0)
                s.append(i+1).append(":").append(arr[i]);
            writer2.println(s.toString());
        
        // writer.println("The first line");
        // writer.println("The second line");
        writer2.close();

    


训练得到的权重全部为零，我不明白为什么？请帮忙，谢谢。

【问题讨论】:

【参考方案1】:

下面这段代码是什么意思？

```java
if(yb==t1.label())
    System.out.println("label: "+t1.label()+", predicted: "+yb);
return 0;
```

它总是返回 0。

【讨论】:

我写这段代码只是为了查看预测正确的结果；所有的预测都等于零，因为所有的权重都是零。

以上是关于 apache spark mllib 中的逻辑回归 - mnist 的主要内容。如果未能解决你的问题，请参考以下文章：

MlLib--逻辑回归笔记

Apache Spark MLlib:如何从 PMML 导入模型

资料推荐:Spark-mllib 源码分析之逻辑回归

Spark MLlib速成宝典模型篇02逻辑斯谛回归Logistic回归(Python版)

Spark MLlib 机器学习

Spark MLlib 机器学习