如何在我的 Spark 管道中集成 ALS 以实现非负矩阵分解?

Posted

技术标签:

【中文标题】如何在我的 Spark 管道中集成 ALS 以实现非负矩阵分解?【英文标题】:How to integrate ALS in my spark pipeline to implement Non-negative matrix factorization? 【发布时间】:2018-01-06 15:04:48 【问题描述】:

我正在使用 spark mllib 训练朴素贝叶斯分类器模型,在该模型中我创建了一个管道来索引我的字符串特征,然后规范化并应用 PCA 进行降维,然后我训练我的朴素贝叶斯模型。当我运行管道时,我在 PCA 分量向量中得到负值。在谷歌搜索时,我发现我必须应用 NMF(非负矩阵分解)来获得正向量,我发现 ALS 将使用方法 .setnonnegative(true) 实现 NMF ,但我不知道如何在 PCA 之后将 ALS 集成到我的管道中。任何帮助表示赞赏。谢谢。

这里是代码

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.NaiveBayes;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.ml.feature.PCA;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

public class NBTrainPCA 
    public static void main(String args[])
        try
            SparkConf conf = new SparkConf().setAppName("NBTrain");
            SparkContext scc = new SparkContext(conf);
            scc.setLogLevel("ERROR");
            JavaSparkContext sc = new JavaSparkContext(scc);
            SQLContext sqlc = new SQLContext(scc);
            DataFrame traindata = sqlc.read().format("parquet").load(args[0]).filter("user_email!='NA' and user_email!='00' and user_email!='0ed709b5bec77b6bff96ea5b5e334a8e5' and user_email is not null  and ip is not null  and region_code is not null and city is not null and browser_name is not null and os_name is not null");
            traindata.registerTempTable("master");
            //DataFrame data = sqlc.sql("select user_email,user_device,ip,country_code,region_code,city,zip_code,time_zone,browser_name,browser_manf,os_name,os_manf from master where user_email!='NA' and user_email is not null and user_device is not null and ip is not null and country_code is not null and region_code is not null and city is not null and browser_name is not null and browser_manf is not null and zip_code is not null and time_zone is not null and os_name is not null and os_manf is not null");
            StringIndexerModel emailIndexer = new StringIndexer()
              .setInputCol("user_email")
              .setOutputCol("email_index")
              .setHandleInvalid("skip")
              .fit(traindata);
            StringIndexer udevIndexer = new StringIndexer()
              .setInputCol("user_device")
              .setOutputCol("udev_index")
              .setHandleInvalid("skip");
            StringIndexer ipIndexer = new StringIndexer()
              .setInputCol("ip")
              .setOutputCol("ip_index")
              .setHandleInvalid("skip");
            StringIndexer ccodeIndexer = new StringIndexer()
              .setInputCol("country_code")
              .setOutputCol("ccode_index")
              .setHandleInvalid("skip");
            StringIndexer rcodeIndexer = new StringIndexer()
              .setInputCol("region_code")
              .setOutputCol("rcode_index")
              .setHandleInvalid("skip");
            StringIndexer cyIndexer = new StringIndexer()
              .setInputCol("city")
              .setOutputCol("cy_index")
              .setHandleInvalid("skip");
            StringIndexer zpIndexer = new StringIndexer()
              .setInputCol("zip_code")
              .setOutputCol("zp_index")
              .setHandleInvalid("skip");
            StringIndexer tzIndexer = new StringIndexer()
              .setInputCol("time_zone")
              .setOutputCol("tz_index")
              .setHandleInvalid("skip");
            StringIndexer bnIndexer = new StringIndexer()
              .setInputCol("browser_name")
              .setOutputCol("bn_index")
              .setHandleInvalid("skip");
            StringIndexer bmIndexer = new StringIndexer()
              .setInputCol("browser_manf")
              .setOutputCol("bm_index")
              .setHandleInvalid("skip");
            StringIndexer bvIndexer = new StringIndexer()
              .setInputCol("browser_version")
              .setOutputCol("bv_index")
              .setHandleInvalid("skip");
            StringIndexer onIndexer = new StringIndexer()
              .setInputCol("os_name")
              .setOutputCol("on_index")
              .setHandleInvalid("skip");
            StringIndexer omIndexer = new StringIndexer()
              .setInputCol("os_manf")
              .setOutputCol("om_index")
              .setHandleInvalid("skip");
            VectorAssembler assembler = new VectorAssembler()
              .setInputCols(new String[] "udev_index","ip_index","ccode_index","rcode_index","cy_index","zp_index","tz_index","bn_index","bm_index","bv_index","on_index","om_index")
              .setOutputCol("ffeatures");
            Normalizer normalizer = new Normalizer()
              .setInputCol("ffeatures")
              .setOutputCol("sfeatures")
              .setP(1.0);
            PCA pca = new PCA()
                .setInputCol("sfeatures")
                .setOutputCol("pcafeatures")
                .setK(5);
            NaiveBayes nbcl = new NaiveBayes()
            .setFeaturesCol("pcafeatures")
            .setLabelCol("email_index")
            .setSmoothing(1.0);
            IndexToString is = new IndexToString()
            .setInputCol("prediction")
            .setOutputCol("op")
            .setLabels(emailIndexer.labels());
            Pipeline pipeline = new Pipeline()
              .setStages(new PipelineStage[] emailIndexer,udevIndexer,ipIndexer,ccodeIndexer,rcodeIndexer,cyIndexer,zpIndexer,tzIndexer,bnIndexer,bmIndexer,bvIndexer,onIndexer,omIndexer,assembler,normalizer,pca,nbcl,is);
            PipelineModel model = pipeline.fit(traindata);
            //DataFrame chidata = model.transform(data);
            //chidata.write().format("com.databricks.spark.csv").save(args[1]);
            model.write().overwrite().save(args[1]);
            sc.close();
            
            catch(Exception e)

            
    

【问题讨论】:

【参考方案1】:

我建议您阅读一些有关 PCA 的内容,以便更好地了解它的作用。这里有一些链接:

https://stats.stackexchange.com/questions/26352/interpreting-positive-and-negative-signs-of-the-elements-of-pca-eigenvectors

https://stats.stackexchange.com/questions/2691/making-sense-of-principal-component-analysis-eigenvectors-eigenvalues

在将 ALS 集成到您的管道时,您似乎只想一个接一个地插入一个东西。更好地了解他们每个人的工作和用途:ALS 和 PCA 是完全不同的东西。 ALS 正在使用 ALS 进行矩阵分解以最小化误差,没有找到任何主成分来对数据应用变换或降维。

顺便说一句:我认为在 PCA 分量向量中获取负值没有任何问题。您可以在上面的链接中检查这一点。您正在对数据应用线性变换。所以新向量现在是转换的结果。 我希望它有所帮助。

【讨论】:

在 PCA 分量向量中获取负值存在问题,朴素贝叶斯在特征集中没有占用负值。这正是问题所在。 参考此链接***.com/questions/36491852/… 阅读那里的评论:“NMF 在 Spark 中实现,它在分解原始矩阵时不考虑正交性,因此它可能不适用于您的应用程序。” ALS 矩阵分解没有做任何接近 PCA 的事情。 在我上面的回复中,PCA 可以给出负特征值。如果朴素贝叶斯不能处理它们,那么也许考虑另一种方法。关键是 ALS 不会成为解决方案。

以上是关于如何在我的 Spark 管道中集成 ALS 以实现非负矩阵分解?的主要内容,如果未能解决你的问题,请参考以下文章

如何在我的 Android 应用程序中集成 Quick Blox 视频聊天服务?

iOS - 如何在我的应用程序中集成蓝牙设备

在我的 Reactjs 网络应用程序中集成谷歌地图后,如何获得“纬度”和“经度”?

如何在我的网站中集成 CCavenue 支付网关?

如何在我的应用程序中集成 ATOM 支付网关?

如何在我的页面中集成 Google Analytics 数据?