无法找出 Spark LinearRegression 错误的原因

Posted

技术标签:

【中文标题】无法找出 Spark LinearRegression 错误的原因【英文标题】:Can't figure out cause of Spark LinearRegression Error 【发布时间】:2016-09-08 00:19:48 【问题描述】:

我正在尝试使用我在 Kaggle 上找到的住房数据集在 PySpark 中做一个非常简单的LinearRegression。有一堆列,但为了使这(实际上)尽可能简单,我只保留了其中的两列(在从所有列开始之后),并且仍然没有运气训练模型。这是进行回归步骤之前数据框的样子:

2016-09-07 17:12:08,804 root INFO [Row(price=78000.0, sqft_living=780.0, sqft_lot=16344.0, features=DenseVector([780.0, 16344.0])), Row(price=80000.0, sqft_living=430.0, sqft_lot=5050.0, features=DenseVector([430.0, 5050.0])), Row(price=81000.0, sqft_living=730.0, sqft_lot=9975.0, features=DenseVector([730.0, 9975.0])), Row(price=82000.0, sqft_living=860.0, sqft_lot=10426.0, features=DenseVector([860.0, 10426.0])), Row(price=84000.0, sqft_living=700.0, sqft_lot=20130.0, features=DenseVector([700.0, 20130.0])), Row(price=85000.0, sqft_living=830.0, sqft_lot=9000.0, features=DenseVector([830.0, 9000.0])), Row(price=85000.0, sqft_living=910.0, sqft_lot=9753.0, features=DenseVector([910.0, 9753.0])), Row(price=86500.0, sqft_living=840.0, sqft_lot=9480.0, features=DenseVector([840.0, 9480.0])), Row(price=89000.0, sqft_living=900.0, sqft_lot=4750.0, features=DenseVector([900.0, 4750.0])), Row(price=89950.0, sqft_living=570.0, sqft_lot=4080.0, features=DenseVector([570.0, 4080.0]))]

我正在使用以下代码来训练模型:

    standard_scaler = StandardScaler(inputCol='features',
                                     outputCol='scaled')
    lr = LinearRegression(featuresCol=standard_scaler.getOutputCol(), labelCol='price', weightCol=None,
                          maxIter=100, tol=1e-4)
    pipeline = Pipeline(stages=[standard_scaler, lr])
    grid = (ParamGridBuilder()
            .baseOn(lr.labelCol: 'price')
            .addGrid(lr.regParam, [0.1, 1.0])
            .addGrid(lr.elasticNetParam, elastic_net_params or [0.0, 1.0])
            .build())
    ev = RegressionEvaluator(metricName="rmse", labelCol='price')
    cv = CrossValidator(estimator=pipeline,
                        estimatorParamMaps=grid,
                        evaluator=ev,
                        numFolds=5)
    model = cv.fit(data).bestModel

我得到的错误是:

2016-09-07 17:12:08,805 root INFO Training regression model...
2016-09-07 17:12:09,530 root ERROR An error occurred while calling o60.fit.
: java.lang.NullPointerException
    at org.apache.spark.ml.regression.LinearRegression.train(LinearRegression.scala:164)
    at org.apache.spark.ml.regression.LinearRegression.train(LinearRegression.scala:70)
    at org.apache.spark.ml.Predictor.fit(Predictor.scala:90)
    at org.apache.spark.ml.Predictor.fit(Predictor.scala:71)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:237)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:280)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:128)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:211)
    at java.lang.Thread.run(Thread.java:745)

有什么想法吗?

【问题讨论】:

感谢@evan-zamir。我遇到了同样的错误,您在下面的评论解决了它。只是我没有使用 1.0 作为权重,而是删除了权重参数。这看起来像是 Spark 中应该报告的错误。 【参考方案1】:

在这种情况下,您不能使用Pipeline。当您致电 pipeline.fit 时,它会转换为(大致)

standard_scaler_model = standard_scaler.fit(dataframe)
lr_model = lr.fit(dataframe)

但你确实需要

standard_scaler_model = standard_scaler.fit(dataframe)
dataframe = standard_scaler_model.transform(dataframe)
lr_model = lr.fit(dataframe)

错误是因为您的lr.fit 找不到您的StandardScaler 模型的输出(即转换的结果)。

【讨论】:

这里的错误不是由StandardScaler引起的。这对我来说很好(显然你的经历不同)。错误原来是weight 列。当我尝试指定 weightCol=None 时,这对我来说是错误的。我通过创建一个权重为 1.0 的 weightCol 来修复它(必须是浮点数!)。

以上是关于无法找出 Spark LinearRegression 错误的原因的主要内容,如果未能解决你的问题,请参考以下文章

Spark:如何找出两个数据集之间的不同元素?

找出 2 个表 (`tbl_spark`) 是不是相等而不使用 sparklyr 收集它们

有没有办法找出 Spark Web UI 正在使用的端口?

Spark SQL 窗口平均值问题

Apache Spark Hadoop S3A SignatureDoesNotMatch

从 Spark 调用休息服务