spark pipeline KMeansModel clusterCenters

Posted 2016-05-24 03:34:03

Question:

I am using a pipeline to cluster text documents. The last stage in the pipeline is ml.clustering.KMeans, which gives me a DataFrame with a column of cluster predictions. I would like to add the cluster centers as a column as well. I understand I can do Vector[] clusterCenters = kmeansModel.clusterCenters(); and then convert the result to a DataFrame and join it to the other DataFrame; however, I was hoping to find a way to accomplish this in a manner similar to the KMeans code below:

    KMeans kMeans = new KMeans()
                .setFeaturesCol("pca")
                .setPredictionCol("kmeansclusterprediction")
                .setK(5)
                .setInitMode("random")
                .setSeed(43L)
                .setInitSteps(3)
                .setMaxIter(15);

pipeline.setStages( ...

I was able to extend KMeans and call its fit method via a pipeline, but I haven't had any luck extending KMeansModel ... the constructor requires a String uid and a KMeansModel, and I don't know how to pass in the model when defining the stages and calling the setStages method.

I also looked into extending KMeans.scala, but being a Java developer I only understand about half of that code, so I'm hoping someone may have an easier solution before I tackle it. Ultimately, I would like to end up with a DataFrame like this:

+--------------------+-----------------------+--------------------+
|               docid|kmeansclusterprediction|kmeansclustercenters|
+--------------------+-----------------------+--------------------+
|2bcbcd54-c11a-48c...|                      2|      [-0.04, -7.72]|
|0e644620-f5ff-40f...|                      3|        [0.23, 1.08]|
|665c1c2b-3065-4e8...|                      3|        [0.23, 1.08]|
|598c6268-e4b9-4c9...|                      0|      [-15.81, 0.01]|
+--------------------+-----------------------+--------------------+ 
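
For reference, the join-based workaround mentioned above would look roughly like the following sketch (Spark 1.6 APIs assumed; kmeansModel, predictionsDF, and sqlContext are placeholders for the fitted KMeans stage, the pipeline output, and its SQLContext), but I would rather keep everything inside the pipeline:

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.mllib.linalg.Vector;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;

    Vector[] clusterCenters = kmeansModel.clusterCenters();

    // ml.KMeans predictions are the index of the nearest center,
    // so row i of clusterCenters corresponds to prediction value i
    List<Row> centerRows = new ArrayList<Row>();
    for (int i = 0; i < clusterCenters.length; i++) {
        centerRows.add(RowFactory.create(i, Arrays.toString(clusterCenters[i].toArray())));
    }

    StructType centerSchema = DataTypes.createStructType(Arrays.asList(
        DataTypes.createStructField("kmeansclusterprediction", DataTypes.IntegerType, false),
        DataTypes.createStructField("kmeansclustercenters", DataTypes.StringType, false)));

    DataFrame centersDF = sqlContext.createDataFrame(centerRows, centerSchema);

    // join the centers back onto the predictions produced by the pipeline
    DataFrame withCenters = predictionsDF.join(centersDF, "kmeansclusterprediction");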

Any help or hints are greatly appreciated. Thanks.


Answer 1:

Answering my own question ... this turned out to be fairly easy ... I extended both KMeans and KMeansModel ... the fit method of the extended KMeans has to return the extended KMeansModel. For example:

public class AnalyticsKMeansModel extends KMeansModel ...


public class AnalyticsKMeans extends org.apache.spark.ml.clustering.KMeans  ...

    public AnalyticsKMeansModel fit(DataFrame dataset) {

        JavaRDD<Vector> javaRDD = dataset.select(this.getFeaturesCol()).toJavaRDD()
            .map(new Function<Row, Vector>() {
                private static final long serialVersionUID = -4588981547209486909L;

                @Override
                public Vector call(Row row) throws Exception {
                    Object point = row.getAs("pca");
                    Vector vector = (Vector) point;
                    return vector;
                }
            });

        RDD<Vector> rdd = JavaRDD.toRDD(javaRDD);
        org.apache.spark.mllib.clustering.KMeans algo = new org.apache.spark.mllib.clustering.KMeans()
            .setK(BoxesRunTime.unboxToInt(this.$((Param<?>) this.k())))
            .setInitializationMode((String) this.$(this.initMode()))
            .setInitializationSteps(BoxesRunTime.unboxToInt((Object) this.$((Param<?>) this.initSteps())))
            .setMaxIterations(BoxesRunTime.unboxToInt((Object) this.$((Param<?>) this.maxIter())))
            .setSeed(BoxesRunTime.unboxToLong((Object) this.$((Param<?>) this.seed())))
            .setEpsilon(BoxesRunTime.unboxToDouble((Object) this.$((Param<?>) this.tol())));
        org.apache.spark.mllib.clustering.KMeansModel parentModel = algo.run(rdd);
        AnalyticsKMeansModel model = new AnalyticsKMeansModel(this.uid(), parentModel);
        return (AnalyticsKMeansModel) this.copyValues((Params) model, this.copyValues$default$2());
    }

Once I changed the fit method to return my extended KMeansModel class, everything worked as expected.
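
With the fit method returning the extended model, the custom estimator drops into a pipeline like the stock one. A minimal usage sketch (inputDF and the uid string are placeholders; the column names match the question):

    import org.apache.spark.ml.Pipeline;
    import org.apache.spark.ml.PipelineModel;
    import org.apache.spark.ml.PipelineStage;
    import org.apache.spark.sql.DataFrame;

    AnalyticsKMeans kMeans = new AnalyticsKMeans("analyticskmeans");
    kMeans.setFeaturesCol("pca")
        .setPredictionCol("kmeansclusterprediction")
        .setK(5)
        .setInitMode("random")
        .setSeed(43L)
        .setInitSteps(3)
        .setMaxIter(15);

    Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { kMeans });

    // Pipeline.fit dispatches to AnalyticsKMeans.fit, so the fitted stage is an
    // AnalyticsKMeansModel whose transform appends the cluster-center column
    PipelineModel pipelineModel = pipeline.fit(inputDF);        // inputDF: placeholder input
    DataFrame clustered = pipelineModel.transform(inputDF);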

Comments:

Hey, I'm trying to do the same thing but can't make sense of your code. Could you post the complete code?

@DnA, I added the three classes above ... hope that helps. I haven't touched the code in a long time, but I believe it was in working order the last time I did. The code was written for learning purposes and was never used in a production environment.

Answer 2:
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.ml.clustering.KMeansModel;
    import org.apache.spark.mllib.linalg.Vector;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;

    // AnalyticsCluster (defined below) lives in the same package, so no import is required

    public class AnalyticsKMeansModel extends KMeansModel {
        private static final long serialVersionUID = -8893355418042946358L;

        public AnalyticsKMeansModel(String uid, org.apache.spark.mllib.clustering.KMeansModel parentModel) {
            super(uid, parentModel);
        }

        public DataFrame transform(DataFrame dataset) {

            Vector[] clusterCenters = super.clusterCenters();

            List<AnalyticsCluster> analyticsClusters = new ArrayList<AnalyticsCluster>();

            for (int i = 0; i < clusterCenters.length; i++) {
                Integer clusterId = super.predict(clusterCenters[i]);
                Vector vector = clusterCenters[i];
                double[] point = vector.toArray();
                AnalyticsCluster analyticsCluster = new AnalyticsCluster(clusterId, point, 0L);
                analyticsClusters.add(analyticsCluster);
            }

            JavaSparkContext jsc = JavaSparkContext.fromSparkContext(dataset.sqlContext().sparkContext());

            JavaRDD<AnalyticsCluster> javaRDD = jsc.parallelize(analyticsClusters);

            JavaRDD<Row> javaRDDRow = javaRDD.map(new Function<AnalyticsCluster, Row>() {
                private static final long serialVersionUID = -2677295862916670965L;

                @Override
                public Row call(AnalyticsCluster cluster) throws Exception {
                    Row row = RowFactory.create(
                        String.valueOf(cluster.getID()),
                        String.valueOf(Arrays.toString(cluster.getCenter()))
                    );
                    return row;
                }
            });

            List<StructField> schemaColumns = new ArrayList<StructField>();
            schemaColumns.add(DataTypes.createStructField(this.getPredictionCol(), DataTypes.StringType, false));
            schemaColumns.add(DataTypes.createStructField("clusterpoint", DataTypes.StringType, false));

            StructType dataFrameSchema = DataTypes.createStructType(schemaColumns);

            DataFrame clusterPointsDF = dataset.sqlContext().createDataFrame(javaRDDRow, dataFrameSchema);

            // sometimes "k" is set to a value greater than the number of actual rows of data,
            // so keep only the distinct (prediction, center) pairs
            clusterPointsDF.registerTempTable("clusterPoints");
            DataFrame clustersDF = clusterPointsDF.sqlContext().sql(
                "select distinct " + this.getPredictionCol() + ", clusterpoint from clusterPoints");
            clustersDF.cache();
            clusterPointsDF.sqlContext().dropTempTable("clusterPoints");

            DataFrame transformedDF = super.transform(dataset);
            transformedDF.cache();

            DataFrame df = transformedDF.join(clustersDF,
                    transformedDF.col(this.getPredictionCol()).equalTo(clustersDF.col(this.getPredictionCol())), "inner")
                .drop(clustersDF.col(this.getPredictionCol()));

            return df;
        }
    }





    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.ml.param.Param;
    import org.apache.spark.ml.param.Params;
    import org.apache.spark.mllib.linalg.Vector;
    import org.apache.spark.rdd.RDD;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.Row;

    import scala.runtime.BoxesRunTime;

    public class AnalyticsKMeans extends org.apache.spark.ml.clustering.KMeans {
        private static final long serialVersionUID = 8943702485821267996L;
        private static String uid = null;

        public AnalyticsKMeans(String uid) {
            AnalyticsKMeans.uid = uid;
        }

        public AnalyticsKMeansModel fit(DataFrame dataset) {

            JavaRDD<Vector> javaRDD = dataset.select(this.getFeaturesCol()).toJavaRDD()
                .map(new Function<Row, Vector>() {
                    private static final long serialVersionUID = -4588981547209486909L;

                    @Override
                    public Vector call(Row row) throws Exception {
                        Object point = row.getAs("pca");
                        Vector vector = (Vector) point;
                        return vector;
                    }
                });

            RDD<Vector> rdd = JavaRDD.toRDD(javaRDD);

            // copy the ml.KMeans params onto the underlying mllib.KMeans implementation
            org.apache.spark.mllib.clustering.KMeans algo = new org.apache.spark.mllib.clustering.KMeans()
                .setK(BoxesRunTime.unboxToInt(this.$((Param<?>) this.k())))
                .setInitializationMode((String) this.$(this.initMode()))
                .setInitializationSteps(BoxesRunTime.unboxToInt((Object) this.$((Param<?>) this.initSteps())))
                .setMaxIterations(BoxesRunTime.unboxToInt((Object) this.$((Param<?>) this.maxIter())))
                .setSeed(BoxesRunTime.unboxToLong((Object) this.$((Param<?>) this.seed())))
                .setEpsilon(BoxesRunTime.unboxToDouble((Object) this.$((Param<?>) this.tol())));

            org.apache.spark.mllib.clustering.KMeansModel parentModel = algo.run(rdd);
            AnalyticsKMeansModel model = new AnalyticsKMeansModel(this.uid(), parentModel);
            return (AnalyticsKMeansModel) this.copyValues((Params) model, this.copyValues$default$2());
        }
    }




    import java.io.Serializable;
    import java.util.Arrays;

    public class AnalyticsCluster implements Serializable {
        private static final long serialVersionUID = 6535671221958712594L;

        private final int id;
        private volatile double[] center;
        private volatile long count;

        public AnalyticsCluster(int id, double[] center, long initialCount) {
    //      Preconditions.checkArgument(center.length > 0);
    //      Preconditions.checkArgument(initialCount >= 1);
            this.id = id;
            this.center = center;
            this.count = initialCount;
        }

        public int getID() {
            return id;
        }

        public double[] getCenter() {
            return center;
        }

        public long getCount() {
            return count;
        }

        // incrementally merge newCount points centered at newPoint into this cluster's running mean
        public synchronized void update(double[] newPoint, long newCount) {
            int length = center.length;
    //      Preconditions.checkArgument(length == newPoint.length);
            double[] newCenter = new double[length];
            long newTotalCount = newCount + count;
            double newToTotal = (double) newCount / newTotalCount;
            for (int i = 0; i < length; i++) {
                double centerI = center[i];
                newCenter[i] = centerI + newToTotal * (newPoint[i] - centerI);
            }
            center = newCenter;
            count = newTotalCount;
        }

        @Override
        public synchronized String toString() {
            return id + " " + Arrays.toString(center) + " " + count;
        }

    //  public static void main(String[] args) {
    //      double[] point = new double[2];
    //      point[0] = 0.10150532938119154;
    //      point[1] = -0.23734759238651829;
    //
    //      Cluster cluster = new Cluster(1, point, 10L);
    //      System.out.println("cluster: " + cluster.toString());
    //  }
    }

Comments:

I added the complete code above (three classes). I hadn't touched the code in over a year, and it was my first attempt at learning Spark. I believe I was using Spark 1.6.
