SparkMLlib---LinearRegression(线性回归)LogisticRegression(逻辑回归)
Posted 汪本成
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了SparkMLlib---LinearRegression(线性回归)LogisticRegression(逻辑回归)相关的知识,希望对你有一定的参考价值。
1、随机梯度下降
首先介绍一下随机梯度下降算法:
1.1、代码一:
package mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable.HashMap

/**
 * Stochastic gradient descent demo.
 *
 * Fits the slope `a` of the model y = a * x against 50 points generated
 * from the true line y = 2x, so `a` should converge towards 2.0.
 */
object SGD {
  // Silence verbose Spark / Jetty logging on the console.
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  // Spark context (local, single thread); app name is the class name without '$'.
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))
  println(this.getClass().getSimpleName().filter(!_.equals('$')))
  val sc = new SparkContext(conf)

  // Training set container: x -> y pairs.
  val data = new HashMap[Int, Int]()

  /**
   * Builds the dataset: i -> 2 * i for i in 1..50 (the line y = 2x).
   *
   * @return the populated (mutable, shared) map of training pairs
   */
  def getData(): HashMap[Int, Int] = {
    for (i <- 1 to 50) {
      data += (i -> (2 * i))
    }
    data
  }

  // Current estimate of the slope; initialised to 0.
  var a: Double = 0
  // Learning rate (step size).
  var b: Double = 0.1

  /**
   * One SGD update for the squared-error objective of y = a*x:
   * a <- a - b * (a*x - y).
   */
  def sgd(x: Double, y: Double) = {
    a = a - b * ((a * x) - y)
  }

  def main(args: Array[String]) {
    val dataSource = getData()
    println("data: ")
    dataSource.foreach(each => println(each + " "))
    println("\nresult: ")
    var num = 1
    // One pass over the data, printing the estimate before each update.
    // NOTE(review): HashMap iteration order is unspecified, so the update
    // sequence (and intermediate prints) may vary between runs.
    dataSource.foreach { myMap =>
      println(num + ":" + a + "(" + myMap._1 + "," + myMap._2 + ")")
      sgd(myMap._1, myMap._2)
      num = num + 1
    }
    // Final fitted slope.
    println("最终结果a为 " + a)
  }
}
2、线性回归
2.1、数据
首先是做下小数据集的实验,测试的公式在代码中有说明,实验数据如下:
5,1 1
7,2 1
10,2 2
9,3 2
11,4 1
19,5 3
18,6 2
2.2、代码二:
package mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Linear regression on a small dataset.
 *
 * Model: f(x) = a*x1 + b*x2. Each input line has the form
 * "label,x1 x2" (comma-separated label, space-separated features).
 */
object LinearRegression1 {
  // Silence verbose Spark / Jetty logging on the console.
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  // Spark context (local, single thread); app name is the class name without '$'.
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))
  println(this.getClass().getSimpleName().filter(!_.equals('$')))
  val sc = new SparkContext(conf)

  def main(args: Array[String]) {
    // Load the small training set from local disk.
    val data = sc.textFile("G:\\MLlibData\\lpsa2.txt")

    // Parse "label,f1 f2" lines into LabeledPoints; cached because SGD
    // iterates over the RDD many times.
    val parsedData = data.map { line =>
      val parts = line.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }.cache()

    // Train with SGD: 100 iterations, step size 0.1.
    val numIterations = 100
    val stepSize = 0.1
    val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize)

    // Predict the label for the feature vector (2, 1).
    val result = model.predict(Vectors.dense(2, 1))
    println("model weights:")
    // The two fitted coefficients, stored as a vector.
    println(model.weights)
    println(result)
    sc.stop()
  }
}
3、回归曲线
回归曲线这块我们不仅对比预测结果和真实结果,还要计算回归曲线的MSE。
3.1、数据
-0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
-0.1625189,-1.98898046126935 -0.722008756122123 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
-0.1625189,-1.57881887548545 -2.1887840293994 1.36116336875686 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541
-0.1625189,-2.16691708463163 -0.807993896938655 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
0.3715636,-0.507874475300631 -0.458834049396776 -0.250631301876899 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
0.7654678,-2.03612849966376 -0.933954647105133 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
0.8544153,-0.557312518810673 -0.208756571683607 -0.787896192088153 0.990146852537193 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.2669476,-0.929360463147704 -0.0578991819441687 0.152317365781542 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.2669476,-2.28833047634983 -0.0706369432557794 -0.116315079324086 0.80409888772376 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.2669476,0.223498042876113 -1.41471935455355 -0.116315079324086 -1.02470580167082 -0.522940888712441 -0.29928234305568 0.342627053981254 0.199211097885341
1.3480731,0.107785900236813 -1.47221551299731 0.420949810887169 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.687186906466865
1.446919,0.162180092313795 -1.32557369901905 0.286633588334355 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.4701758,-1.49795329918548 -0.263601072284232 0.823898478545609 0.788388310173035 -0.522940888712441 -0.29928234305568 0.342627053981254 0.199211097885341
1.4929041,0.796247055396743 0.0476559407005752 0.286633588334355 -1.02470580167082 -0.522940888712441 0.394013435896129 -1.04215728919298 -0.864466507337306
1.5581446,-1.62233848461465 -0.843294091975396 -3.07127197548598 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.5993876,-0.990720665490831 0.458513517212311 0.823898478545609 1.07379746308195 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.6389967,-0.171901281967138 -0.489197399065355 -0.65357996953534 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.6956156,-1.60758252338831 -0.590700340358265 -0.65357996953534 -0.619561070667254 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
1.7137979,0.366273918511144 -0.414014962912583 -0.116315079324086 0.232904453212813 -0.522940888712441 0.971228997418125 0.342627053981254 1.26288870310799
1.8000583,-0.710307384579833 0.211731938156277 0.152317365781542 -1.02470580167082 -0.522940888712441 -0.442797990776478 0.342627053981254 1.61744790484887
1.8484548,-0.262791728113881 -1.16708345615721 0.420949810887169 0.0846342590816532 -0.522940888712441 0.163172393491611 0.342627053981254 1.97200710658975
1.8946169,0.899043117369237 -0.590700340358265 0.152317365781542 -1.02470580167082 -0.522940888712441 1.28643254437683 -1.04215728919298 -0.864466507337306
1.9242487,-0.903451690500615 1.07659722048274 0.152317365781542 1.28380453408541 -0.522940888712441 -0.442797990776478 -1.04215728919298 -0.864466507337306
2.008214,-0.0633337899773081 -1.38088970920094 0.958214701098423 0.80409888772376 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
2.0476928,-1.15393789990757 -0.961853075398404 -0.116315079324086 -1.02470580167082 -0.522940888712441 -0.442797990776478 -1.04215728919298 -0.864466507337306
2.1575593,0.0620203721138446 0.0657973885499142 1.22684714620405 -0.468824786336838 -0.522940888712441 1.31421001659859 1.72741139715549 -0.332627704725983
2.1916535,-0.75731027755674 -2.92717970468456 0.018001143228728 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.332627704725983
2.2137539,1.11226993252773 1.06484916245061 0.555266033439982 0.877691038550889 1.89254797819741 1.43890404648442 0.342627053981254 0.376490698755783
2.2772673,-0.468768642850639 -1.43754788774533 -1.05652863719378 0.576050411655607 -0.522940888712441 0.0120483832567209 0.342627053981254 -0.687186906466865
2.2975726,-0.618884859896728 -1.1366360750781 -0.519263746982526 -1.02470580167082 -0.522940888712441 -0.863171185425945 3.11219574032972 1.97200710658975
2.3272777,-0.651431999123483 0.55329161145762 -0.250631301876899 1.11210019001038 -0.522940888712441 -0.179808625688859 -1.04215728919298 -0.864466507337306
2.5217206,0.115499102435224 -0.512233676577595 0.286633588334355 1.13650173283446 -0.522940888712441 -0.179808625688859 0.342627053981254 -0.155348103855541
2.5533438,0.266341329949937 -0.551137885443386 -0.384947524429713 0.354857790686005 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.332627704725983
2.5687881,1.16902610257751 0.855491905752846 2.03274448152093 1.22628985326088 1.89254797819741 2.02833774827712 3.11219574032972 2.68112551007152
2.6567569,-0.218972367124187 0.851192298581141 0.555266033439982 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 0.908329501367106
2.677591,0.263121415733908 1.4142681068416 0.018001143228728 1.35980653053822 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
2.7180005,-0.0704736333296423 1.52000996595417 0.286633588334355 1.39364261119802 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.332627704725983
2.7942279,-0.751957286017338 0.316843561689933 -1.99674219506348 0.911736065044475 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
2.8063861,-0.685277652430997 1.28214038482516 0.823898478545609 0.232904453212813 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541
2.8124102,-0.244991501432929 0.51882005949686 -0.384947524429713 0.823246560137838 -0.522940888712441 -0.863171185425945 0.342627053981254 0.553770299626224
2.8419982,-0.75731027755674 2.09041984898851 1.22684714620405 1.53428167116843 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
2.8535925,1.20962937075363 -0.242882661178889 1.09253092365124 -1.02470580167082 -0.522940888712441 1.24263233939889 3.11219574032972 2.50384590920108
2.9204698,0.570886990493502 0.58243883987948 0.555266033439982 1.16006887775962 -0.522940888712441 1.07357183940747 0.342627053981254 1.61744790484887
2.9626924,0.719758684343624 0.984970304132004 1.09253092365124 1.52137230773457 -0.522940888712441 -0.179808625688859 0.342627053981254 -0.509907305596424
2.9626924,-1.52406140158064 1.81975700990333 0.689582255992796 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
2.9729753,-0.132431544081234 2.68769877553723 1.09253092365124 1.53428167116843 -0.522940888712441 -0.442797990776478 0.342627053981254 -0.687186906466865
3.0130809,0.436161292804989 -0.0834447307428255 -0.519263746982526 -1.02470580167082 1.89254797819741 1.07357183940747 0.342627053981254 1.26288870310799
3.0373539,-0.161195191984091 -0.671900359186746 1.7641120364153 1.13650173283446 -0.522940888712441 -0.863171185425945 0.342627053981254 0.0219314970149
3.2752562,1.39927182372944 0.513852869452676 0.689582255992796 -1.02470580167082 1.89254797819741 1.49394503405693 0.342627053981254 -0.155348103855541
3.3375474,1.51967002306341 -0.852203755696565 0.555266033439982 -0.104527297798983 1.89254797819741 1.85927724828569 0.342627053981254 0.908329501367106
3.3928291,0.560725834706224 1.87867703391426 1.09253092365124 1.39364261119802 -0.522940888712441 0.486423065822545 0.342627053981254 1.26288870310799
3.4355988,1.00765532502814 1.69426310090641 1.89842825896812 1.53428167116843 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.509907305596424
3.4578927,1.10152996153577 -0.10927271844907 0.689582255992796 -1.02470580167082 1.89254797819741 1.97630171771485 0.342627053981254 1.61744790484887
3.5160131,0.100001934217311 -1.30380956369388 0.286633588334355 0.316555063757567 -0.522940888712441 0.28786643052924 0.342627053981254 0.553770299626224
3.5307626,0.987291634724086 -0.36279314978779 -0.922212414640967 0.232904453212813 -0.522940888712441 1.79270085261407 0.342627053981254 1.26288870310799
3.5652984,1.07158528137575 0.606453149641961 1.7641120364153 -0.432854616994416 1.89254797819741 0.528504607720369 0.342627053981254 0.199211097885341
3.5876769,0.180156323255198 0.188987436375017 -0.519263746982526 1.09956763075594 -0.522940888712441 0.708239632330506 0.342627053981254 0.199211097885341
3.6309855,1.65687973755377 -0.256675483533719 0.018001143228728 -1.02470580167082 1.89254797819741 1.79270085261407 0.342627053981254 1.26288870310799
3.6800909,0.5720085322365 0.239854450210939 -0.787896192088153 1.0605418233138 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
3.7123518,0.323806133438225 -0.606717660886078 -0.250631301876899 -1.02470580167082 1.89254797819741 0.342907418101747 0.342627053981254 0.199211097885341
3.9843437,1.23668206715898 2.54220539083611 0.152317365781542 -1.02470580167082 1.89254797819741 1.89037692416194 0.342627053981254 1.26288870310799
3.993603,0.180156323255198 0.154448192444669 1.62979581386249 0.576050411655607 1.89254797819741 0.708239632330506 0.342627053981254 1.79472750571931
4.029806,1.60906277046565 1.10378605019827 0.555266033439982 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
4.1295508,1.0036214996026 0.113496885050331 -0.384947524429713 0.860016436332751 1.89254797819741 -0.863171185425945 0.342627053981254 -0.332627704725983
4.3851468,1.25591974271076 0.577607033774471 0.555266033439982 -1.02470580167082 1.89254797819741 1.07357183940747 0.342627053981254 1.26288870310799
4.6844434,2.09650591351268 0.625488598331018 -2.66832330782754 -1.02470580167082 1.89254797819741 1.67954222367555 0.342627053981254 0.553770299626224
5.477509,1.30028987435881 0.338383613253713 0.555266033439982 1.00481276295349 1.89254797819741 1.24263233939889 0.342627053981254 1.97200710658975
3.2、代码三:
package mllib

import java.text.SimpleDateFormat
import java.util.Date

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Computes the MSE of a fitted regression line.
 *
 * Trains a model on multiple records, predicts on the same records
 * (training and prediction data are NOT split), prints the model
 * weights and saves (label, prediction) pairs to a timestamped folder.
 *
 * Model: f(x) = a1*X1 + a2*X2 + a3*X3 + ...
 */
object LinearRegression2 {
  // Silence verbose Spark / Jetty logging on the console.
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  // Spark context (local, single thread); app name is the class name without '$'.
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))
  println(this.getClass().getSimpleName().filter(!_.equals('$')))
  val sc = new SparkContext(conf)

  def main(args: Array[String]) {
    // Load the dataset (single partition).
    val data = sc.textFile("G:\\MLlibData\\lpsa.data", 1)

    // Parse "label,f1 f2 ... fn" lines into LabeledPoints.
    val parsedData = data.map { line =>
      val parts = line.split(",")
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }

    // Train with SGD; note the training and prediction data are not split.
    val numIterations = 100
    val model = LinearRegressionWithSGD.train(parsedData, numIterations, 0.1)
    // for (i <- parsedData) println(i.label + ":" + i.features)

    // Pair each true label with the model's prediction for its features.
    val valuesAndPreds = parsedData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }

    // Print the fitted weight vector.
    val weights = model.weights
    println("model.weights" + weights)

    // Save (label, prediction) pairs under a timestamped output folder
    // so repeated runs do not collide.
    val isString = new SimpleDateFormat("yyyyMMddHHmmssSSS").format(new Date())
    val path = "G:\\MLlibData\\saveFile\\" + isString + "\\results"
    valuesAndPreds.saveAsTextFile(path)

    // Mean squared error over the training set.
    val MSE = valuesAndPreds.map { case (v, p) =>
      math.pow((v - p), 2)
    }.reduce(_ + _) / valuesAndPreds.count

    println("训练的数据集的均方误差是: " + MSE)
    sc.stop()
  }
}
注意:MLlib中的线性回归比较适合做一元线性回归而非多元线性回归,当回归系数比较多时,算法产生的过拟合现象较为严重。
4、逻辑回归
4.1、数据
这里包括了我写的一元逻辑回归和多元逻辑回归,数据用的是spark工程下的sample_libsvm_data.txt文件和我自己弄的logisticRegression1.data,内容如下:
1|2
1|3
1|4
1|5
1|6
0|7
0|8
0|9
0|10
0|11
4.2、代码四
package mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithSGD}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Logistic regression demo: a univariate model trained from a
 * "label|feature" text file, and a multivariate model trained from a
 * LIBSVM file with a 60/40 train/test split evaluated via precision.
 */
object LogisticRegression {
  // Silence verbose Spark / Jetty logging on the console.
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  val conf = new SparkConf()
    .setMaster("local[4]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))
  val sc = new SparkContext(conf)

  // Helper holding the predict-and-label routine (companion class below).
  var logisticRegression = new LogisticRegression

  // Univariate logistic regression dataset ("label|feature" per line).
  val LR1_PATH = "file\\data\\mllib\\input\\regression\\logisticRegression1.data"
  // Multivariate logistic regression dataset (LIBSVM format).
  val LR2_PATH = "file\\data\\mllib\\input\\regression\\sample_libsvm_data.txt"

  val data = sc.textFile(LR1_PATH)
  val svmData = MLUtils.loadLibSVMFile(sc, LR2_PATH)

  // 60/40 train/test split with a fixed seed for reproducibility.
  val splits = svmData.randomSplit(Array(0.6, 0.4), seed = 11L)
  val parsedData_SVM = splits(0)
  val parsedTest_SVM = splits(1)

  // Parse "label|feature" lines into LabeledPoints; cached because SGD
  // iterates over the RDD many times.
  val parsedData = data.map { line =>
    val parts = line.split('|')
    LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
  }.cache()

  // Train both models with 50 SGD iterations.
  val model = LogisticRegressionWithSGD.train(parsedData, 50)
  val svmModel = LogisticRegressionWithSGD.train(parsedData_SVM, 50)

  // A test point for the univariate model.
  val target = Vectors.dense(-1)
  // Predicted class for the test point.
  val predict = model.predict(target)

  // (label, prediction) pairs for the multivariate test split.
  val predict_svm = logisticRegression.predictAndLabels(parsedTest_SVM, svmModel)

  // Multiclass evaluation over the test predictions.
  val metrics = new MulticlassMetrics(predict_svm)
  val precision = metrics.precision

  def main(args: Array[String]) {
    println("一元逻辑回归:")
    parsedData.foreach(println)
    // Fitted weights of the univariate model.
    println("权重: " + model.weights)
    println(predict)
    println(model.predict(Vectors.dense(10)))
    println("*************************************************************")
    println("多元逻辑回归:")
    println("svmData记录数:" + svmData.count())
    println("parsedData_SVM:" + parsedData_SVM.count())
    println("parsedTest_SVM:" + parsedTest_SVM.count())
    println("Precision = " + precision)
    // Show a sample of (label, prediction) pairs.
    predict_svm.take(10).foreach(println)
    println("权重: " + svmModel.weights)
    println("weights 个数是: " + svmModel.weights.size)
    // Count of non-zero weights (of the univariate model).
    println("weights不为0的个数是: " + model.weights.toArray.filter(_ != 0).size)
    sc.stop()
  }
}

/**
 * Helper class for scoring a labeled test set with a trained model.
 */
class LogisticRegression {

  /**
   * Predicts each point's class and pairs it with the true label.
   *
   * @param data  labeled test set
   * @param model trained logistic regression model
   * @return RDD of (true label, predicted label) pairs
   */
  def predictAndLabels(
      data: RDD[LabeledPoint],
      model: LogisticRegressionModel): RDD[(Double, Double)] = {
    val parsedData = data.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
    parsedData
  }
}
以上是关于SparkMLlib---LinearRegression(线性回归)LogisticRegression(逻辑回归)的主要内容,如果未能解决你的问题,请参考以下文章