Spark MLlib: LinearRegression and LogisticRegression

Posted by 汪本成



1. Stochastic gradient descent

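First, an introduction to the stochastic gradient descent (SGD) algorithm, used here to fit the slope a of y = a·x on a data set generated from y = 2x: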
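For reference, the textbook per-sample SGD step for this one-parameter model with a squared-error loss is given below. The notation is mine: η is the step size (called b in the code), and (x_i, y_i) is the current sample.

```latex
L_i(a) = \tfrac{1}{2}\,(a x_i - y_i)^2, \qquad
\frac{\partial L_i}{\partial a} = (a x_i - y_i)\,x_i, \qquad
a \leftarrow a - \eta\,(a x_i - y_i)\,x_i
```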

1.1 Code 1:

package mllib


import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkContext, SparkConf}

import scala.collection.mutable.HashMap

/**
  * Stochastic gradient descent
  * Created by 汪本成 on 2016/8/7.
  */
object SGD {

  // Suppress unnecessary log output in the terminal
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

  // Program entry point
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))

  println(this.getClass().getSimpleName().filter(!_.equals('$')))

  val sc = new SparkContext(conf)

  // HashMap that stores the data set
  val data = new HashMap[Int, Int]()
  // Generate the data set
  def getData(): HashMap[Int, Int] = {
    for (i <- 1 to 50) {
      data += (i -> (2 * i))  // follows the formula y = 2x
    }
    data
  }

  // Initial guess a = 0
  var a: Double = 0
  // Step size (learning rate)
  var b: Double = 0.1

  // Update rule. Note: the exact gradient of (a*x - y)^2 / 2 with respect to a
  // is (a*x - y) * x; this simplified update drops the factor x.
  def sgd(x: Double, y: Double): Unit = {
    a = a - b * ((a * x) - y)
  }

  def main(args: Array[String]): Unit = {
    // Get the data set
    val dataSource = getData()
    println("data: ")
    dataSource.foreach(each => println(each + " "))
    println("\nresult: ")
    var num = 1
    // Start iterating: one update per (x, y) pair
    dataSource.foreach(myMap => {
      println(num + ":" + a + "(" + myMap._1 + "," + myMap._2 + ")")
      sgd(myMap._1, myMap._2)
      num = num + 1
    })
    // Show the result
    println("Final result a = " + a)
    sc.stop()
  }
}

Please run it and check the result yourself.
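As a Spark-free cross-check, here is a minimal sketch of SGD on the same y = 2x data that uses the exact squared-error gradient (a·x − y)·x. The object name SgdSketch, the helper fitSlope, the learning rate 0.0001 and the 20 epochs are my own choices for illustration. They are not part of the original program.

```scala
// Hypothetical sketch (not from the original post): per-sample SGD for y = a * x
// using the exact squared-error gradient, runnable without Spark.
object SgdSketch {
  /** Fits y ≈ a * x with one gradient step per (x, y) pair per epoch and returns a. */
  def fitSlope(data: Seq[(Double, Double)], learningRate: Double, epochs: Int): Double = {
    var a = 0.0
    for (_ <- 1 to epochs; (x, y) <- data) {
      // Gradient of 0.5 * (a * x - y)^2 with respect to a
      val gradient = (a * x - y) * x
      a -= learningRate * gradient
    }
    a
  }

  def main(args: Array[String]): Unit = {
    // Same data set as in code 1: y = 2x for x = 1..50
    val data = (1 to 50).map(i => (i.toDouble, 2.0 * i))
    println(s"estimated a = ${fitSlope(data, learningRate = 0.0001, epochs = 20)}")
  }
}
```

The learning rate has to be much smaller than the original 0.1. With the exact gradient, each step multiplies the error a − 2 by (1 − η·x²), and with η = 0.1 that factor exceeds 1 in magnitude as soon as x ≥ 5, so the iteration would diverge.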

2. Linear regression

2.1 Data

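We start with an experiment on a small data set. The formula under test is given in the code comment (f(x) = a·x1 + b·x2). Each line has the form label,x1 x2: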

5,1 1
7,2 1
10,2 2
9,3 2
11,4 1
19,5 3
18,6 2
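To get an exact reference for the SGD weights, you can solve the least-squares fit of f(x) = a·x1 + b·x2 (no intercept) on these seven rows directly from the 2×2 normal equations. The sketch below uses Cramer's rule and plain Scala collections. NormalEquations2D and solveNoIntercept are hypothetical names, not part of MLlib.

```scala
// Hypothetical reference solver (not MLlib): least squares for f(x) = a*x1 + b*x2
// without intercept, via the 2x2 normal equations (X^T X) w = X^T y and Cramer's rule.
object NormalEquations2D {
  // The seven "label,x1 x2" rows above as (y, x1, x2)
  val sampleRows: Seq[(Double, Double, Double)] = Seq(
    (5.0, 1.0, 1.0), (7.0, 2.0, 1.0), (10.0, 2.0, 2.0), (9.0, 3.0, 2.0),
    (11.0, 4.0, 1.0), (19.0, 5.0, 3.0), (18.0, 6.0, 2.0))

  /** Returns (a, b) minimising the sum of (a*x1 + b*x2 - y)^2. */
  def solveNoIntercept(rows: Seq[(Double, Double, Double)]): (Double, Double) = {
    val s11 = rows.map { case (_, x1, _) => x1 * x1 }.sum
    val s12 = rows.map { case (_, x1, x2) => x1 * x2 }.sum
    val s22 = rows.map { case (_, _, x2) => x2 * x2 }.sum
    val t1 = rows.map { case (y, x1, _) => x1 * y }.sum
    val t2 = rows.map { case (y, _, x2) => x2 * y }.sum
    val det = s11 * s22 - s12 * s12
    require(det != 0.0, "features are collinear")
    ((t1 * s22 - s12 * t2) / det, (s11 * t2 - s12 * t1) / det)
  }

  def main(args: Array[String]): Unit = {
    val (a, b) = solveNoIntercept(sampleRows)
    println(f"a = $a%.4f, b = $b%.4f, f(2, 1) = ${2 * a + b}%.4f")
  }
}
```

For this data the exact solution is a = 736/344 ≈ 2.140 and b = 858/344 ≈ 2.494, so f(2, 1) ≈ 6.773. The weights printed by LinearRegressionWithSGD will only be close to these if SGD has converged.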

2.2 Code 2:

package mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LinearRegressionWithSGD, LabeledPoint}
import org.apache.spark.{SparkContext, SparkConf}

/**
  * Linear regression 1: small data set
  * Formula: f(x) = a*x1 + b*x2
  * Created by 汪本成 on 2016/8/6.
  */
object LinearRegression1 {

  // Suppress unnecessary log output in the terminal
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

  // Program entry point
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))
  println(this.getClass().getSimpleName().filter(!_.equals('$')))

  val sc = new SparkContext(conf)

  def main(args: Array[String]): Unit = {
    // Path of the data set
    val data = sc.textFile("G:\\MLlibData\\lpsa2.txt")
    // Parse each "label,x1 x2" line
    val parsedData = data.map { line =>
      val parts = line.split(',')
      // Convert to LabeledPoint(label, feature vector)
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }.cache()
    // Train the model
    val numIterations = 100
    val stepSize = 0.1
    val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize)
    // Use the trained model to predict f(2, 1)
    val result = model.predict(Vectors.dense(2, 1))
    println("model weights:")
    // The two fitted coefficients, stored as a vector
    println(model.weights)
    println(result)
    sc.stop()
  }
}


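To make the training step of code 2 less of a black box, here is a rough single-machine imitation of what I understand LinearRegressionWithSGD.train(data, numIterations, stepSize) to do. It starts from zero weights, averages the least-squares gradient over all rows (miniBatchFraction = 1.0, no regularization), and shrinks the step to stepSize / sqrt(t) at iteration t. This reflects my reading of the library, not its source code. Spark's convergence-tolerance check and data partitioning are left out, and DecayingStepGd is a hypothetical name.

```scala
// Hypothetical sketch (my rough reading of LinearRegressionWithSGD, not its source):
// full-batch gradient descent on 0.5 * (w.x - y)^2 with a decaying step size.
object DecayingStepGd {
  // The seven "label,x1 x2" rows from section 2.1 as (label, features)
  val sampleRows: Seq[(Double, Array[Double])] = Seq(
    (5.0, Array(1.0, 1.0)), (7.0, Array(2.0, 1.0)), (10.0, Array(2.0, 2.0)), (9.0, Array(3.0, 2.0)),
    (11.0, Array(4.0, 1.0)), (19.0, Array(5.0, 3.0)), (18.0, Array(6.0, 2.0)))

  private def dot(w: Array[Double], x: Array[Double]): Double =
    w.indices.map(j => w(j) * x(j)).sum

  /** Runs numIterations full-batch steps with step size stepSize / sqrt(t). */
  def train(rows: Seq[(Double, Array[Double])], numIterations: Int, stepSize: Double): Array[Double] = {
    val dim = rows.head._2.length
    val w = Array.fill(dim)(0.0) // start from zero weights
    for (t <- 1 to numIterations) {
      val grad = Array.fill(dim)(0.0)
      for ((y, x) <- rows) {
        val diff = dot(w, x) - y // prediction error for this row
        for (j <- 0 until dim) grad(j) += diff * x(j)
      }
      val step = stepSize / math.sqrt(t) // decaying step size
      for (j <- 0 until dim) w(j) -= step * grad(j) / rows.size // averaged gradient
    }
    w
  }

  /** Mean squared error of weights w on rows. */
  def mse(rows: Seq[(Double, Array[Double])], w: Array[Double]): Double =
    rows.map { case (y, x) => math.pow(dot(w, x) - y, 2) }.sum / rows.size

  def main(args: Array[String]): Unit = {
    val w = train(sampleRows, numIterations = 100, stepSize = 0.1)
    println(s"weights = ${w.mkString("[", ", ", "]")}, f(2, 1) = ${dot(w, Array(2.0, 1.0))}")
  }
}
```

Comparing the printed weights with an exact least-squares solution is a quick way to check whether 100 iterations are enough for a given data set.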
3. Regression curve

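For the regression curve, we compare the predicted values with the true values and also compute the mean squared error (MSE) of the fitted curve.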
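The MSE is the average squared difference between true and predicted values, MSE = (1/n)·Σ(yᵢ − ŷᵢ)². The sketch below is a minimal plain-Scala version of that formula, without Spark. MseSketch is a hypothetical name. If you prefer a library call, Spark's org.apache.spark.mllib.evaluation.RegressionMetrics also exposes a meanSquaredError value.

```scala
// Hypothetical sketch (not from the original post): MSE over (trueValue, prediction) pairs.
object MseSketch {
  /** Mean of (v - p)^2 over all pairs. */
  def mse(valuesAndPreds: Seq[(Double, Double)]): Double = {
    require(valuesAndPreds.nonEmpty, "need at least one pair")
    valuesAndPreds.map { case (v, p) => math.pow(v - p, 2) }.sum / valuesAndPreds.size
  }

  def main(args: Array[String]): Unit =
    println(mse(Seq((1.0, 1.5), (2.0, 2.0), (3.0, 2.0))))
}
```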

3.1 Data

-0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
-0.1625189,-1.98898046126935 -0.722008756122123 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
-0.1625189,-1.57881887548545 -2.1887840293994 1.36116336875686 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541
-0.1625189,-2.16691708463163 -0.807993896938655 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
0.3715636,-0.507874475300631 -0.458834049396776 -0.250631301876899 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
0.7654678,-2.03612849966376 -0.933954647105133 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
0.8544153,-0.557312518810673 -0.208756571683607 -0.787896192088153 0.990146852537193 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.2669476,-0.929360463147704 -0.0578991819441687 0.152317365781542 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.2669476,-2.28833047634983 -0.0706369432557794 -0.116315079324086 0.80409888772376 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.2669476,0.223498042876113 -1.41471935455355 -0.116315079324086 -1.02470580167082 -0.522940888712441 -0.29928234305568 0.342627053981254 0.199211097885341
1.3480731,0.107785900236813 -1.47221551299731 0.420949810887169 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.687186906466865
1.446919,0.162180092313795 -1.32557369901905 0.286633588334355 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.4701758,-1.49795329918548 -0.263601072284232 0.823898478545609 0.788388310173035 -0.522940888712441 -0.29928234305568 0.342627053981254 0.199211097885341
1.4929041,0.796247055396743 0.0476559407005752 0.286633588334355 -1.02470580167082 -0.522940888712441 0.394013435896129 -1.04215728919298 -0.864466507337306
1.5581446,-1.62233848461465 -0.843294091975396 -3.07127197548598 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.5993876,-0.990720665490831 0.458513517212311 0.823898478545609 1.07379746308195 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.6389967,-0.171901281967138 -0.489197399065355 -0.65357996953534 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.6956156,-1.60758252338831 -0.590700340358265 -0.65357996953534 -0.619561070667254 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.7137979,0.366273918511144 -0.414014962912583 -0.116315079324086 0.232904453212813 -0.522940888712441 0.971228997418125 0.342627053981254 1.26288870310799
1.8000583,-0.710307384579833 0.211731938156277 0.152317365781542 -1.02470580167082 -0.522940888712441 -0.442797990776478 0.342627053981254 1.61744790484887
1.8484548,-0.262791728113881 -1.16708345615721 0.420949810887169 0.0846342590816532 -0.522940888712441 0.163172393491611 0.342627053981254 1.97200710658975
1.8946169,0.899043117369237 -0.590700340358265 0.152317365781542 -1.02470580167082 -0.522940888712441 1.28643254437683 -1.04215728919298 -0.864466507337306
1.9242487,-0.903451690500615 1.07659722048274 0.152317365781542 1.28380453408541 -0.522940888712441 -0.442797990776478 -1.04215728919298 -0.864466507337306
2.008214,-0.0633337899773081 -1.38088970920094 0.958214701098423 0.80409888772376 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
2.0476928,-1.15393789990757 -0.961853075398404 -0.116315079324086 -1.02470580167082 -0.522940888712441 -0.442797990776478 -1.04215728919298 -0.864466507337306
2.1575593,0.0620203721138446 0.0657973885499142 1.22684714620405 -0.468824786336838 -0.522940888712441 1.31421001659859 1.72741139715549 -0.332627704725983
2.1916535,-0.75731027755674 -2.92717970468456 0.018001143228728 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.332627704725983
2.2137539,1.11226993252773 1.06484916245061 0.555266033439982 0.877691038550889 1.89254797819741 1.43890404648442 0.342627053981254 0.376490698755783
2.2772673,-0.468768642850639 -1.43754788774533 -1.05652863719378 0.576050411655607 -0.522940888712441 0.0120483832567209 0.342627053981254 -0.687186906466865
2.2975726,-0.618884859896728 -1.1366360750781 -0.519263746982526 -1.02470580167082 -0.522940888712441 -0.863171185425945 3.11219574032972 1.97200710658975
2.3272777,-0.651431999123483 0.55329161145762 -0.250631301876899 1.11210019001038 -0.522940888712441 -0.179808625688859 -1.04215728919298 -0.864466507337306
2.5217206,0.115499102435224 -0.512233676577595 0.286633588334355 1.13650173283446 -0.522940888712441 -0.179808625688859 0.342627053981254 -0.155348103855541
2.5533438,0.266341329949937 -0.551137885443386 -0.384947524429713 0.354857790686005 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.332627704725983
2.5687881,1.16902610257751 0.855491905752846 2.03274448152093 1.22628985326088 1.89254797819741 2.02833774827712 3.11219574032972 2.68112551007152
2.6567569,-0.218972367124187 0.851192298581141 0.555266033439982 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 0.908329501367106
2.677591,0.263121415733908 1.4142681068416 0.018001143228728 1.35980653053822 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
2.7180005,-0.0704736333296423 1.52000996595417 0.286633588334355 1.39364261119802 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.332627704725983
2.7942279,-0.751957286017338 0.316843561689933 -1.99674219506348 0.911736065044475 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
2.8063861,-0.685277652430997 1.28214038482516 0.823898478545609 0.232904453212813 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541
2.8124102,-0.244991501432929 0.51882005949686 -0.384947524429713 0.823246560137838 -0.522940888712441 -0.863171185425945 0.342627053981254 0.553770299626224
2.8419982,-0.75731027755674 2.09041984898851 1.22684714620405 1.53428167116843 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
2.8535925,1.20962937075363 -0.242882661178889 1.09253092365124 -1.02470580167082 -0.522940888712441 1.24263233939889 3.11219574032972 2.50384590920108
2.9204698,0.570886990493502 0.58243883987948 0.555266033439982 1.16006887775962 -0.522940888712441 1.07357183940747 0.342627053981254 1.61744790484887
2.9626924,0.719758684343624 0.984970304132004 1.09253092365124 1.52137230773457 -0.522940888712441 -0.179808625688859 0.342627053981254 -0.509907305596424
2.9626924,-1.52406140158064 1.81975700990333 0.689582255992796 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
2.9729753,-0.132431544081234 2.68769877553723 1.09253092365124 1.53428167116843 -0.522940888712441 -0.442797990776478 0.342627053981254 -0.687186906466865
3.0130809,0.436161292804989 -0.0834447307428255 -0.519263746982526 -1.02470580167082 1.89254797819741 1.07357183940747 0.342627053981254 1.26288870310799
3.0373539,-0.161195191984091 -0.671900359186746 1.7641120364153 1.13650173283446 -0.522940888712441 -0.863171185425945 0.342627053981254 0.0219314970149
3.2752562,1.39927182372944 0.513852869452676 0.689582255992796 -1.02470580167082 1.89254797819741 1.49394503405693 0.342627053981254 -0.155348103855541
3.3375474,1.51967002306341 -0.852203755696565 0.555266033439982 -0.104527297798983 1.89254797819741 1.85927724828569 0.342627053981254 0.908329501367106
3.3928291,0.560725834706224 1.87867703391426 1.09253092365124 1.39364261119802 -0.522940888712441 0.486423065822545 0.342627053981254 1.26288870310799
3.4355988,1.00765532502814 1.69426310090641 1.89842825896812 1.53428167116843 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.509907305596424
3.4578927,1.10152996153577 -0.10927271844907 0.689582255992796 -1.02470580167082 1.89254797819741 1.97630171771485 0.342627053981254 1.61744790484887
3.5160131,0.100001934217311 -1.30380956369388 0.286633588334355 0.316555063757567 -0.522940888712441 0.28786643052924 0.342627053981254 0.553770299626224
3.5307626,0.987291634724086 -0.36279314978779 -0.922212414640967 0.232904453212813 -0.522940888712441 1.79270085261407 0.342627053981254 1.26288870310799
3.5652984,1.07158528137575 0.606453149641961 1.7641120364153 -0.432854616994416 1.89254797819741 0.528504607720369 0.342627053981254 0.199211097885341
3.5876769,0.180156323255198 0.188987436375017 -0.519263746982526 1.09956763075594 -0.522940888712441 0.708239632330506 0.342627053981254 0.199211097885341
3.6309855,1.65687973755377 -0.256675483533719 0.018001143228728 -1.02470580167082 1.89254797819741 1.79270085261407 0.342627053981254 1.26288870310799
3.6800909,0.5720085322365 0.239854450210939 -0.787896192088153 1.0605418233138 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
3.7123518,0.323806133438225 -0.606717660886078 -0.250631301876899 -1.02470580167082 1.89254797819741 0.342907418101747 0.342627053981254 0.199211097885341
3.9843437,1.23668206715898 2.54220539083611 0.152317365781542 -1.02470580167082 1.89254797819741 1.89037692416194 0.342627053981254 1.26288870310799
3.993603,0.180156323255198 0.154448192444669 1.62979581386249 0.576050411655607 1.89254797819741 0.708239632330506 0.342627053981254 1.79472750571931
4.029806,1.60906277046565 1.10378605019827 0.555266033439982 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
4.1295508,1.0036214996026 0.113496885050331 -0.384947524429713 0.860016436332751 1.89254797819741 -0.863171185425945 0.342627053981254 -0.332627704725983
4.3851468,1.25591974271076 0.577607033774471 0.555266033439982 -1.02470580167082 1.89254797819741 1.07357183940747 0.342627053981254 1.26288870310799
4.6844434,2.09650591351268 0.625488598331018 -2.66832330782754 -1.02470580167082 1.89254797819741 1.67954222367555 0.342627053981254 0.553770299626224
5.477509,1.30028987435881 0.338383613253713 0.555266033439982 1.00481276295349 1.89254797819741 1.24263233939889 0.342627053981254 1.97200710658975

3.2 Code 3:

package mllib

import java.text.SimpleDateFormat
import java.util.Date

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LinearRegressionWithSGD, LabeledPoint}
import org.apache.spark.{SparkContext, SparkConf}

/**
  * Compute the MSE of the regression curve
  * Train a model on multi-feature data, then use model.predict to get concrete values
  * The model weights are printed along the way
  * Formula: f(x) = a1*X1 + a2*X2 + a3*X3 + ...
  * Created by 汪本成 on 2016/8/7.
  */
object LinearRegression2 {

  // Suppress unnecessary log output in the terminal
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

  // Program entry point
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))
  println(this.getClass().getSimpleName().filter(!_.equals('$')))

  val sc = new SparkContext(conf)

  def main(args: Array[String]): Unit = {
    // Path of the data set
    val data = sc.textFile("G:\\MLlibData\\lpsa.data", 1)
    // Parse each "label,x1 x2 ... x8" line
    val parsedData = data.map { line =>
      val parts = line.split(",")
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }.cache()

    // Train the model
    // Note: the same data is used for training and for prediction (no train/test split)
    val numIterations = 100
    val model = LinearRegressionWithSGD.train(parsedData, numIterations, 0.1)
    //for (i <- parsedData) println(i.label + ":" + i.features)

    // Pair each true value with its prediction
    val valuesAndPreds = parsedData.map { point =>
      // Predict from the feature vector
      val prediction = model.predict(point.features)
      // Store as (label, prediction)
      (point.label, prediction)
    }

    // Print the weights
    val weights = model.weights
    println("model.weights: " + weights)

    // Save the (label, prediction) pairs to a time-stamped directory
    val timestamp = new SimpleDateFormat("yyyyMMddHHmmssSSS").format(new Date())
    val path = "G:\\MLlibData\\saveFile\\" + timestamp + "\\results"
    valuesAndPreds.saveAsTextFile(path)
    val MSE = valuesAndPreds.map { case (v, p) => math.pow(v - p, 2) }
      .reduce(_ + _) / valuesAndPreds.count
    println("Mean squared error on the training data: " + MSE)
    sc.stop()
  }
}

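Note: in my experience, MLlib's linear regression (LinearRegressionWithSGD) is better suited to simple regression with a single variable than to multiple regression. When there are many regression coefficients, the model overfits fairly badly.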

4. Logistic regression

4.1 Data

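This section covers my single-variable logistic regression and a multivariate logistic regression. The data are the sample_libsvm_data.txt file that ships with the Spark project and my own logisticRegression1.data. Its contents (one label|x pair per line) are: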

1|2
1|3
1|4
1|5
1|6
0|7
0|8
0|9
0|10
0|11
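The model fitted on this file is p(y = 1 | x) = sigmoid(w·x). The sketch below reproduces the idea on the ten rows above without Spark. It makes the following assumptions, all mine:
- a single weight and no intercept (as far as I know, the static LogisticRegressionWithSGD.train methods do not add an intercept by default);
- full-batch gradient descent with a fixed step of 0.01 for 1000 iterations;
- a 0.5 decision threshold.

LogisticSketch and its members are hypothetical names.

```scala
// Hypothetical sketch (not MLlib): one-feature logistic regression without intercept,
// trained by full-batch gradient descent on the mean log loss.
object LogisticSketch {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // The ten "label|x" rows above as (label, x)
  val sampleData: Seq[(Double, Double)] =
    (2 to 6).map(x => (1.0, x.toDouble)) ++ (7 to 11).map(x => (0.0, x.toDouble))

  /** Fits p(y = 1 | x) = sigmoid(w * x) and returns w. */
  def fit(data: Seq[(Double, Double)], stepSize: Double, numIterations: Int): Double = {
    var w = 0.0
    for (_ <- 1 to numIterations) {
      // Gradient of the mean log loss with respect to w: mean of (sigmoid(w*x) - y) * x
      val grad = data.map { case (y, x) => (sigmoid(w * x) - y) * x }.sum / data.size
      w -= stepSize * grad
    }
    w
  }

  /** Returns 1.0 when the predicted probability exceeds 0.5, otherwise 0.0. */
  def predict(w: Double, x: Double): Double = if (sigmoid(w * x) > 0.5) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    val w = fit(sampleData, stepSize = 0.01, numIterations = 1000)
    println(s"w = $w, predict(-1) = ${predict(w, -1)}, predict(10) = ${predict(w, 10)}")
  }
}
```

Label 1 goes with the small x values, so in this sketch the fitted w comes out negative. As a result, x < 0 (such as the test value −1 used in code 4) is classified as 1, and a large x such as 10 is classified as 0.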

4.2 Code 4:

package mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithSGD}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkConf}

/**
  * Logistic regression
  * Created by 汪本成 on 2016/8/7.
  */
object LogisticRegression {

  // Suppress unnecessary log output in the terminal
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)


  val conf = new SparkConf()
    .setMaster("local[4]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))

  val sc = new SparkContext(conf)

  val logisticRegression = new LogisticRegression

  // Single-variable logistic regression data set
  val LR1_PATH = "file\\data\\mllib\\input\\regression\\logisticRegression1.data"
  // Multivariate logistic regression data set
  val LR2_PATH = "file\\data\\mllib\\input\\regression\\sample_libsvm_data.txt"

  val data = sc.textFile(LR1_PATH)
  val svmData = MLUtils.loadLibSVMFile(sc, LR2_PATH)

  // Split the data set: 60% training, 40% test
  val splits = svmData.randomSplit(Array(0.6, 0.4), seed = 11L)
  val parsedData_SVM = splits(0)
  val parsedTest_SVM = splits(1)


  // Convert each "label|x" line to a LabeledPoint
  val parsedData = data.map { line =>
    val parts = line.split('|')
    LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
  }.cache()

  // Train the models (50 iterations each)
  val model = LogisticRegressionWithSGD.train(parsedData, 50)
  val svmModel = LogisticRegressionWithSGD.train(parsedData_SVM, 50)


  // Create a test value
  val target = Vectors.dense(-1.0)

  // Predict the test value with the model
  val predict = model.predict(target)

  // Predict the multivariate test set and keep (prediction, label) pairs
  val predict_svm = logisticRegression.predictAndLabels(parsedTest_SVM, svmModel)

  // Evaluation helper
  val metrics = new MulticlassMetrics(predict_svm)

  // Overall precision (deprecated since Spark 2.0 in favour of metrics.accuracy)
  val precision = metrics.precision

  def main(args: Array[String]): Unit = {
    println("Single-variable logistic regression:")
    parsedData.collect().foreach(println)
    // Print the weights
    println("Weights: " + model.weights)
    println(predict)
    println(model.predict(Vectors.dense(10.0)))

    println("*************************************************************")

    println("Multivariate logistic regression:")
    println("svmData records: " + svmData.count())
    println("parsedData_SVM records: " + parsedData_SVM.count())
    println("parsedTest_SVM records: " + parsedTest_SVM.count())
    println("Precision = " + precision) // print the evaluation metric
    predict_svm.take(10).foreach(println)
    println("Weights: " + svmModel.weights)
    println("Number of weights: " + svmModel.weights.size)
    // Number of non-zero weights of the multivariate model
    println("Number of non-zero weights: " + svmModel.weights.toArray.count(_ != 0))
    sc.stop()
  }
}


class LogisticRegression {

  /**
    * Predicts every point and pairs the prediction with its true label.
    *
    * @param data  test data (here the LibSVM test split)
    * @param model trained LogisticRegressionModel
    * @return (prediction, label) pairs, the order MulticlassMetrics expects
    */
  def predictAndLabels(
      data: RDD[LabeledPoint],
      model: LogisticRegressionModel): RDD[(Double, Double)] = {
    data.map { point =>
      val prediction = model.predict(point.features)
      (prediction, point.label)
    }
  }
}

I won't paste screenshots of the output. Please run the code and check the results yourself.
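As I understand it, the no-argument MulticlassMetrics.precision used in code 4 is simply the fraction of test points whose predicted label equals the true label. Spark 2.0 deprecated it in favour of accuracy. Below is a plain-Scala version of that computation. AccuracySketch is a hypothetical name.

```scala
// Hypothetical sketch (not MLlib): share of (prediction, label) pairs that agree.
object AccuracySketch {
  def accuracy(predictionAndLabels: Seq[(Double, Double)]): Double = {
    require(predictionAndLabels.nonEmpty, "need at least one pair")
    predictionAndLabels.count { case (p, l) => p == l }.toDouble / predictionAndLabels.size
  }

  def main(args: Array[String]): Unit =
    println(accuracy(Seq((1.0, 1.0), (0.0, 1.0), (0.0, 0.0), (1.0, 1.0))))
}
```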
