Spark MLlib: The SGD (Stochastic Gradient Descent) Algorithm
Posted by 汪本成
Code:
package mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable.HashMap

/**
 * Stochastic gradient descent (随机梯度下降算法)
 * Created by 汪本成 on 2016/8/5.
 */
object SGD {

  // Suppress the noisier log output on the terminal
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  // Spark entry point
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass().getSimpleName().filter(!_.equals('$')))
  println(this.getClass().getSimpleName().filter(!_.equals('$')))
  val sc = new SparkContext(conf)

  // HashMap that holds the data set
  val data = new HashMap[Int, Int]()

  // Generate the data set from the formula y = 2x
  def getData(): HashMap[Int, Int] = {
    for (i <- 1 to 50) {
      data += (i -> (2 * i))
    }
    data
  }

  // Start from a = 0
  var a: Double = 0
  // Step size (learning rate)
  var b: Double = 0.1

  // Iterative update rule
  def sgd(x: Double, y: Double) = {
    a = a - b * ((a * x) - y)
  }

  def main(args: Array[String]) {
    // Fetch the data set
    val dataSource = getData()
    println("data: ")
    dataSource.foreach(each => println(each + " "))
    println("\nresult: ")
    var num = 1
    // Iterate over the samples
    dataSource.foreach(myMap => {
      println(num + ":" + a + "(" + myMap._1 + "," + myMap._2 + ")")
      sgd(myMap._1, myMap._2)
      num = num + 1
    })
    // Print the final result
    println("最终结果a为 " + a)
  }
}
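An aside on the update rule: for the squared loss L(a) = 0.5 * (a*x - y)^2, the textbook SGD gradient with respect to a is (a*x - y) * x, so the standard update multiplies the residual by x; the rule above drops that factor. A minimal standalone sketch of the standard update (not from the original post; the object name and step size are illustrative choices):

// Textbook SGD for the squared loss L(a) = 0.5 * (a*x - y)^2,
// whose gradient dL/da = (a*x - y) * x keeps the trailing x.
object StandardSGD {
  var a: Double = 0
  // With the extra x factor, stability needs b * x * x < 2 for every
  // sample; x reaches 50 here, and 0.0005 * 50 * 50 = 1.25 < 2.
  val b: Double = 0.0005

  def step(x: Double, y: Double): Unit = {
    a = a - b * ((a * x) - y) * x
  }

  def main(args: Array[String]): Unit = {
    // A few passes over the same y = 2x data as above
    for (_ <- 1 to 10; i <- 1 to 50) step(i, 2 * i)
    println("a = " + a) // should approach 2.0
  }
}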
Run output:
"C:\\Program Files\\Java\\jdk1.8.0_77\\bin\\java" -Didea.launcher.port=7533 "-Didea.launcher.bin.path=D:\\Program Files (x86)\\JetBrains\\IntelliJ IDEA 15.0.5\\bin" -Dfile.encoding=UTF-8 -classpath "C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\charsets.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\deploy.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\access-bridge-64.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\cldrdata.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\dnsns.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\jaccess.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\jfxrt.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\localedata.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\nashorn.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\sunec.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\sunjce_provider.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\sunmscapi.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\sunpkcs11.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\zipfs.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\javaws.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\jce.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\jfr.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\jfxswt.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\jsse.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\management-agent.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\plugin.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\resources.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\rt.jar;G:\\location\\spark-mllib\\out\\production\\spark-mllib;C:\\Program Files (x86)\\scala\\lib\\scala-actors-migration.jar;C:\\Program Files (x86)\\scala\\lib\\scala-actors.jar;C:\\Program Files (x86)\\scala\\lib\\scala-library.jar;C:\\Program Files (x86)\\scala\\lib\\scala-reflect.jar;C:\\Program Files (x86)\\scala\\lib\\scala-swing.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\datanucleus-api-jdo-3.2.6.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\datanucleus-core-3.2.10.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\datanucleus-rdbms-3.2.9.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\spark-1.6.1-yarn-shuffle.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\spark-assembly-1.6.1-hadoop2.6.0.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\spark-examples-1.6.1-hadoop2.6.0.jar;D:\\Program Files (x86)\\JetBrains\\IntelliJ IDEA 15.0.5\\lib\\idea_rt.jar" com.intellij.rt.execution.application.AppMain mllib.SGD
SGD
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/G:/home/download/spark-1.6.1-bin-hadoop2.6/lib/spark-assembly-1.6.1-hadoop2.6.0.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/G:/home/download/spark-1.6.1-bin-hadoop2.6/lib/spark-examples-1.6.1-hadoop2.6.0.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
16/08/05 00:48:28 INFO Slf4jLogger: Slf4jLogger started
16/08/05 00:48:28 INFO Remoting: Starting remoting
16/08/05 00:48:28 INFO Remoting: Remoting started; listening on addresses :[akka.tcp://sparkDriverActorSystem@192.168.43.1:24009]
data:
(23,46)
(50,100)
(32,64)
(41,82)
(17,34)
(8,16)
(35,70)
(44,88)
(26,52)
(11,22)
(29,58)
(38,76)
(47,94)
(20,40)
(2,4)
(5,10)
(14,28)
(46,92)
(40,80)
(49,98)
(4,8)
(13,26)
(22,44)
(31,62)
(16,32)
(7,14)
(43,86)
(25,50)
(34,68)
(10,20)
(37,74)
(1,2)
(19,38)
(28,56)
(45,90)
(27,54)
(36,72)
(18,36)
(9,18)
(21,42)
(48,96)
(3,6)
(12,24)
(30,60)
(39,78)
(15,30)
(42,84)
(24,48)
(6,12)
(33,66)
result:
1:0.0(23,46)
2:4.6000000000000005(50,100)
3:-8.400000000000002(32,64)
4:24.880000000000006(41,82)
5:-68.92800000000003(17,34)
6:51.649600000000035(8,16)
7:11.929920000000003(35,70)
8:-22.82480000000001(44,88)
9:86.40432000000006(26,52)
10:-133.04691200000013(11,22)
11:15.504691199999996(29,58)
12:-23.65891328(38,76)
13:73.84495718400001(47,94)
14:-263.82634158080003(20,40)
15:267.82634158080003(2,4)
16:214.66107326464004(5,10)
17:108.33053663232002(14,28)
18:-40.53221465292802(46,92)
19:155.1159727505409(40,80)
20:-457.3479182516227(49,98)
21:1793.4568811813288(4,8)
22:1076.8741287087973(13,26)
23:-320.46223861263934(22,44)
24:388.95468633516725(31,62)
25:-810.6048413038511(16,32)
26:489.56290478231085(7,14)
27:148.2688714346932(43,86)
28:-480.6872757344877(25,50)
29:726.0309136017315(34,68)
30:-1735.6741926441557(10,20)
31:2.0000000000002274(37,74)
32:1.999999999999386(1,2)
33:1.9999999999994476(19,38)
34:2.000000000000497(28,56)
35:1.9999999999991056(45,90)
36:2.00000000000313(27,54)
37:1.9999999999946787(36,72)
38:2.000000000013835(18,36)
39:1.999999999988932(9,18)
40:1.999999999998893(21,42)
41:2.0000000000012172(48,96)
42:1.9999999999953737(3,6)
43:1.9999999999967615(12,24)
44:2.000000000000648(30,60)
45:1.999999999998704(39,78)
46:2.0000000000037588(15,30)
47:1.9999999999981206(42,84)
48:2.0000000000060134(24,48)
49:1.999999999991581(6,12)
50:1.9999999999966325(33,66)
最终结果a为 2.0000000000077454
16/08/05 00:48:28 INFO RemoteActorRefProvider$RemotingTerminator: Shutting down remote daemon.
Process finished with exit code 0
Analysis:
When the step size b is 0.1, roughly 30 iterations are usually enough to reach the answer; at 0.5, roughly 15 iterations give the correct result; at 1, the run fails to converge even after all 50 iterations.
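One step of the post's rule has the closed form a <- (1 - b*x)*a + b*y, so progress depends heavily on which x values arrive in which order, and a HashMap traversal order varies between runs; the iteration counts above reflect one such order. A quick sketch (not from the original post; names and the shuffle seed are illustrative) to replay the rule with several step sizes:

// Replay a = a - b * (a*x - y) over one shuffled pass of the
// y = 2x data for each step size, printing a every 10 samples.
object StepSizeSweep {
  def main(args: Array[String]): Unit = {
    val data = new scala.util.Random(42).shuffle((1 to 50).toList).map(i => (i, 2 * i))
    for (b <- Seq(0.1, 0.5, 1.0)) {
      var a = 0.0
      println("step size b = " + b)
      for (((x, y), n) <- data.zipWithIndex) {
        a = a - b * ((a * x) - y)
        if ((n + 1) % 10 == 0) println("  after " + (n + 1) + " samples: a = " + a)
      }
    }
  }
}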
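Finally, note that the code above only uses SparkContext directly and never calls into MLlib itself. For reference, here is a minimal sketch of the same y = 2x fit using MLlib's SGD-based linear regression, which ships with the Spark 1.6.1 used in this post (the iteration count and step size are illustrative choices, not tuned values):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}

object MLlibSGDExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[1]").setAppName("MLlibSGDExample"))

    // The same y = 2x data as LabeledPoints: label = y, single feature = x
    val points = sc.parallelize(
      (1 to 50).map(i => LabeledPoint(2.0 * i, Vectors.dense(i.toDouble))))

    // train(input, numIterations, stepSize); the step size must be
    // small here because the feature values reach 50
    val model = LinearRegressionWithSGD.train(points, 100, 0.001)

    println("weight = " + model.weights(0)) // should approach 2.0
    sc.stop()
  }
}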