Training of Kmeans algorithm failed on Spark
Asked: 2020-07-29 12:55:57

I created a pipeline and tried to train a KMeans clustering algorithm in Spark, but it fails and I cannot pin down the exact error. Here is the code:
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer, VectorAssembler, Normalizer}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SparkSession, functions}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, IntegerType}
val df = spark.read
  .option("header", "false")
  .option("delimiter", " ")
  .csv("HMP_Dataset/*")
  .withColumn("Class", element_at(reverse(split(input_file_name, "/")), 2))
  .withColumn("Source", element_at(reverse(split(input_file_name, "/")), 1))
  .withColumnRenamed("_c0", "X")
  .withColumnRenamed("_c1", "Y")
  .withColumnRenamed("_c2", "Z")
val df2 = df.select(
  df.columns.map {
    case x @ "X" => df(x).cast(DoubleType).as(x)
    case y @ "Y" => df(y).cast(DoubleType).as(y)
    case z @ "Z" => df(z).cast(DoubleType).as(z)
    case other => df(other)
  }: _*
)
val indexer = new StringIndexer().setInputCol("Class").setOutputCol("ClassIndex")
val encoder = new OneHotEncoderEstimator().setInputCols(Array("ClassIndex")).setOutputCols(Array("CategoryVec"))
val assembler = new VectorAssembler().setInputCols(Array("X", "Y", "Z")).setOutputCol("Features")
val normalizer = new Normalizer().setInputCol("Features").setOutputCol("feature_Norm")
val pipeline = new Pipeline().setStages(Array(indexer, encoder, assembler, normalizer))
val model = pipeline.fit(df2).transform(df2)
val train = model.drop("X", "Y", "Z", "Class", "Source", "ClassIndex", "Features")
//model.show()
//train.show()
val kmeans = new KMeans()
  .setFeaturesCol("feature_Norm")
  .setK(2)
  .setSeed(1)
  .setMaxIter(100)
  .fit(train)
  .transform(train)
The train DataFrame is created successfully, but when I pass it to KMeans it throws an error. The error message is:
Failed to execute user defined function($anonfun$4: (struct<X:double,Y:double,Z:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>).
How can I fix this?
Comments:

Could you post a few lines of the file you are trying to read?
@Chema github.com/wchill/HMP_Dataset, this is the link to the dataset I am reading.
Can I see your imports?
@Chema I have updated the question. You can see the imports now.

Answer 1:

Maybe there is a problem with the library versions and the imports; on my laptop the code works fine.
I will show you my build.sbt and the output the code produces.
+--------------+-----------------------------------------------------------+----------+
|CategoryVec |feature_Norm |prediction|
+--------------+-----------------------------------------------------------+----------+
|(13,[0],[1.0])|[0.2574383611739353,0.6931032800836721,0.6733003292241385] |1 |
|(13,[0],[1.0])|[0.22614412777205142,0.6989909403863407,0.6784323833161543]|1 |
|(13,[0],[1.0])|[0.24551225268848764,0.675158694893341,0.6956180492840484] |1 |
|(13,[0],[1.0])|[0.2420417625303279,0.7059551407134563,0.6656148469584017] |1 |
|(13,[0],[1.0])|[0.24214029368137852,0.6860641654305725,0.6860641654305725]|1 |
|(13,[0],[1.0])|[0.24214029368137852,0.6860641654305725,0.6860641654305725]|1 |
|(13,[0],[1.0])|[0.2540244987629046,0.683912112053974,0.683912112053974] |1 |
|(13,[0],[1.0])|[0.2388089256503974,0.6766252893427926,0.6965260331469925] |1 |
|(13,[0],[1.0])|[0.2574383611739353,0.6733003292241385,0.6931032800836721] |1 |
|(13,[0],[1.0])|[0.2572366859677566,0.652985433610459,0.7123477457568644] |1 |
+--------------+-----------------------------------------------------------+----------+
+--------------+------------------------------------------------------------+----------+
|CategoryVec |feature_Norm |prediction|
+--------------+------------------------------------------------------------+----------+
|(13,[5],[1.0])|[0.4673452175282961,0.5098311463945049,0.7222607907255486] |0 |
|(13,[5],[1.0])|[0.4673452175282961,0.5098311463945049,0.7222607907255486] |0 |
|(13,[5],[1.0])|[0.46105396573580254,0.48899663032585117,0.7404806116362889]|0 |
|(13,[5],[1.0])|[0.4369231823814617,0.5214889596165833,0.7329034027043874] |0 |
|(13,[5],[1.0])|[0.45146611838648026,0.5078993831847903,0.7336324423780305] |0 |
|(13,[5],[1.0])|[0.4561664027908625,0.5131872031397203,0.7270152044479371] |0 |
|(13,[5],[1.0])|[0.4561664027908625,0.5131872031397203,0.7270152044479371] |0 |
|(13,[5],[1.0])|[0.45789190653985307,0.49951844349802155,0.7354021529276429]|0 |
|(13,[5],[1.0])|[0.4658526940598004,0.4940861906694853,0.7340709118518067] |0 |
|(13,[5],[1.0])|[0.4625915702820905,0.5046453493986442,0.7289321713535972] |0 |
+--------------+------------------------------------------------------------+----------+
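As a side note, the question imports ClusteringEvaluator but never uses it. For a quick sanity check on clusters like the ones above, here is a minimal sketch (Spark 2.3+), assuming predictions is the DataFrame returned by the KMeans transform, with the feature_Norm and prediction columns shown here:

val evaluator = new ClusteringEvaluator()
  .setFeaturesCol("feature_Norm")
  .setPredictionCol("prediction")

// Silhouette is in [-1, 1]; values close to 1 indicate well-separated clusters.
val silhouette = evaluator.evaluate(predictions)
println(s"Silhouette with squared euclidean distance = $silhouette")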
build.sbt
scalaVersion := "2.11.10"
// https://mvnrepository.com/artifact/org.apache.spark/spark-mllib
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "2.2.0"
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.2.0"
libraryDependencies += "org.apache.spark" % "spark-sql_2.11" % "2.2.0"
Imports
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.feature.{Normalizer, OneHotEncoderEstimator, StringIndexer, VectorAssembler}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DoubleType
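Beyond versions, one common cause of exactly this "Failed to execute user defined function" failure at the VectorAssembler stage is null input: casting a malformed line to DoubleType silently yields null, and the assembler's internal UDF throws on nulls by default. A minimal diagnostic sketch, assuming the df2 DataFrame from the question and import org.apache.spark.sql.functions.col:

// Count rows where the cast to Double produced nulls.
val badRows = df2.filter(col("X").isNull || col("Y").isNull || col("Z").isNull).count()
println(s"Rows with null X/Y/Z after cast: $badRows")

// Option 1: drop the offending rows before fitting the pipeline.
val dfClean = df2.na.drop(Seq("X", "Y", "Z"))

// Option 2 (Spark 2.4+): tell the assembler to skip invalid rows instead.
val assembler = new VectorAssembler()
  .setInputCols(Array("X", "Y", "Z"))
  .setOutputCol("Features")
  .setHandleInvalid("skip")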