Spark regression only working with one feature
Posted: 2021-07-23 15:31:44
Question:
I have some data (~1 MB) about a service provider's customers. I am trying to use Spark (PySpark on Databricks) to predict, from a few features, whether they will end their subscription (churn).
Single-feature model
At first I tried just one feature, and training succeeded:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler
# Create a vector assembler to merge the independent features (in this case just one) into a single vector column
vectorAssembler = VectorAssembler(inputCols=['MonthlyCharges'], outputCol='Charges')
# Create a logistic regression instance that reads this vector column ('Charges') and the churn labels
lr = LogisticRegression(featuresCol='Charges', labelCol='Churn_indexed')
# Select the two relevant columns and put them in a new dataframe
# NOTE: Is this actually hurting performance by using extra memory?
# I wasn't sure if it would expedite the vector assembler transformation
relevant = df_num.select(['MonthlyCharges', 'Churn_indexed'])
# Transform the data using the assembler, then drop the unwanted column ('MonthlyCharges')
# NOTE: Is this selection also not necessary because 'lr' already knows which feature column to use?
curr = vectorAssembler.transform(relevant).select(['Charges', 'Churn_indexed'])
# Create train/test split
train2, test2 = curr.randomSplit([0.8, 0.2], seed=42)
# Fit the model
model = lr.fit(train2)
Two-feature model
However, when I try to use two independent features, I get an error:
vectorAssembler2 = VectorAssembler(inputCols=['MonthlyCharges', 'TotalCharges'], outputCol='Charges')
lr2 = LogisticRegression(featuresCol='Charges', labelCol='Churn_indexed')
relevant = df_num.select(['MonthlyCharges', 'TotalCharges', 'Churn_indexed'])
curr = vectorAssembler2.transform(relevant).select(['Charges', 'Churn_indexed'])
train2, test2 = curr.randomSplit([0.8, 0.2], seed=42)
model = lr2.fit(train2)
Error
Here is the error:
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 30.0 failed 1 times, most recent failure: Lost task 0.0 in stage 30.0 (TID 29) (ip-10-172-254-69.us-west-2.compute.internal executor driver): org.apache.spark.SparkException: Failed to execute user defined function(VectorAssembler$$Lambda$5900/1716232969: (struct<MonthlyCharges:double,TotalCharges:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
Expanding it shows this error:
Py4JJavaError Traceback (most recent call last)
<command-1815097094215178> in <module>
11 curr.show()
12
---> 13 model = lr.fit(train2)
/databricks/python_shell/dbruntime/MLWorkloadsInstrumentation/_pyspark.py in patched_method(self, *args, **kwargs)
28 call_succeeded = False
29 try:
---> 30 result = original_method(self, *args, **kwargs)
31 call_succeeded = True
32 return result
/databricks/spark/python/pyspark/ml/base.py in fit(self, dataset, params)
159 return self.copy(params)._fit(dataset)
160 else:
--> 161 return self._fit(dataset)
162 else:
163 raise ValueError("Params must be either a param map or a list/tuple of param maps, "
/databricks/spark/python/pyspark/ml/wrapper.py in _fit(self, dataset)
333
334 def _fit(self, dataset):
--> 335 java_model = self._fit_java(dataset)
336 model = self._create_model(java_model)
337 return self._copyValues(model)
/databricks/spark/python/pyspark/ml/wrapper.py in _fit_java(self, dataset)
330 """
331 self._transfer_params_to_java()
--> 332 return self._java_obj.fit(dataset._jdf)
333
334 def _fit(self, dataset):
/databricks/spark/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
1302
1303 answer = self.gateway_client.send_command(command)
-> 1304 return_value = get_return_value(
1305 answer, self.gateway_client, self.target_id, self.name)
1306
/databricks/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
115 def deco(*a, **kw):
116 try:
--> 117 return f(*a, **kw)
118 except py4j.protocol.Py4JJavaError as e:
119 converted = convert_exception(e.java_exception)
/databricks/spark/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
325 if answer[1] == REFERENCE_TYPE:
--> 326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
328 format(target_id, ".", name), value)
Py4JJavaError: An error occurred while calling o1465.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 30.0 failed 1 times, most recent failure: Lost task 0.0 in stage 30.0 (TID 29) (ip-10-172-254-69.us-west-2.compute.internal executor driver): org.apache.spark.SparkException: Failed to execute user defined function(VectorAssembler$$Lambda$5900/1716232969: (struct<MonthlyCharges:double,TotalCharges:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.sort_addToSorter_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:757)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at scala.collection.Iterator.foreach(Iterator.scala:941)
at scala.collection.Iterator.foreach$(Iterator.scala:941)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:162)
at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:160)
at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1429)
at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:219)
at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:219)
at scala.collection.AbstractIterator.aggregate(Iterator.scala:1429)
at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$3(RDD.scala:1240)
at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$5(RDD.scala:1241)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:868)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:868)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:60)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:380)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:344)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:150)
at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:119)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.Task.run(Task.scala:91)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:788)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1643)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:791)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:647)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Encountered null while assembling a row with handleInvalid = "error". Consider
removing nulls from dataset or using handleInvalid = "keep" or "skip".
at org.apache.spark.ml.feature.VectorAssembler$.$anonfun$assemble$1(VectorAssembler.scala:292)
at org.apache.spark.ml.feature.VectorAssembler$.$anonfun$assemble$1$adapted(VectorAssembler.scala:261)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
at org.apache.spark.ml.feature.VectorAssembler$.assemble(VectorAssembler.scala:261)
at org.apache.spark.ml.feature.VectorAssembler.$anonfun$transform$6(VectorAssembler.scala:144)
... 41 more
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2765)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2712)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2706)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2706)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1255)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1255)
at scala.Option.foreach(Option.scala:407)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1255)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2973)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2914)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2902)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1028)
at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2446)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2429)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2541)
at org.apache.spark.rdd.RDD.$anonfun$fold$1(RDD.scala:1193)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:419)
at org.apache.spark.rdd.RDD.fold(RDD.scala:1187)
at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$1(RDD.scala:1256)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:419)
at org.apache.spark.rdd.RDD.treeAggregate(RDD.scala:1232)
at org.apache.spark.ml.stat.Summarizer$.getClassificationSummarizers(Summarizer.scala:232)
at org.apache.spark.ml.classification.LogisticRegression.$anonfun$train$1(LogisticRegression.scala:513)
at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:284)
at scala.util.Try$.apply(Try.scala:213)
at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:284)
at org.apache.spark.ml.classification.LogisticRegression.train(LogisticRegression.scala:497)
at org.apache.spark.ml.classification.LogisticRegression.train(LogisticRegression.scala:288)
at org.apache.spark.ml.Predictor.fit(Predictor.scala:151)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
at py4j.Gateway.invoke(Gateway.java:295)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:251)
at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function(VectorAssembler$$Lambda$5900/1716232969: (struct<MonthlyCharges:double,TotalCharges:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.sort_addToSorter_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:757)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at scala.collection.Iterator.foreach(Iterator.scala:941)
at scala.collection.Iterator.foreach$(Iterator.scala:941)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:162)
at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:160)
at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1429)
at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:219)
at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:219)
at scala.collection.AbstractIterator.aggregate(Iterator.scala:1429)
at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$3(RDD.scala:1240)
at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$5(RDD.scala:1241)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:868)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:868)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:60)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:380)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:344)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:150)
at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:119)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.Task.run(Task.scala:91)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:788)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1643)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:791)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:647)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
... 1 more
Caused by: org.apache.spark.SparkException: Encountered null while assembling a row with handleInvalid = "error". Consider
removing nulls from dataset or using handleInvalid = "keep" or "skip".
at org.apache.spark.ml.feature.VectorAssembler$.$anonfun$assemble$1(VectorAssembler.scala:292)
at org.apache.spark.ml.feature.VectorAssembler$.$anonfun$assemble$1$adapted(VectorAssembler.scala:261)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
at org.apache.spark.ml.feature.VectorAssembler$.assemble(VectorAssembler.scala:261)
at org.apache.spark.ml.feature.VectorAssembler.$anonfun$transform$6(VectorAssembler.scala:144)
... 41 more
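Does anyone know what causes this behavior? Any help is greatly appreciated, thanks!
Additional info:
15 GB, 2-core cluster, DBR 8.3, Spark 3.1.1, Scala 2.12
Edit:
The data is customer churn data from a telephone company. See here: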
https://www.kaggle.com/blastchar/telco-customer-churn
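To find where the problem was, I used only one feature (monthly charges). Here is a screenshot of the data (with the charges vectorized): [image not preserved]
Then I also added total charges, which gave me an error, just as it does when I use all the features. Here is the data produced using monthly and total charges: [image not preserved]
Added the full error!
Comments:
Could you provide a sample of the data you are using? If it is private, you can anonymize it. Also, what is the rest of the error after the line Params must be either a param map or a list/tuple of param maps? It should say something like but got ...
Hi Ric S - I added the information, thanks!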
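Answer 1:
The error occurs because your data contains null values:
Caused by: org.apache.spark.SparkException: Encountered null while assembling a row with handleInvalid = "error". Consider removing nulls from dataset or using handleInvalid = "keep" or "skip".
Here are the null counts for the data you shared from Kaggle: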
df = spark.read.option("header", True).csv('WA_Fn-UseC_-Telco-Customer-Churn.csv')
# Count the values in each column that become null when cast to float
print({col: df.filter(df[col].cast('float').isNull()).count() for col in ['MonthlyCharges', 'TotalCharges']})
# {'MonthlyCharges': 0, 'TotalCharges': 11}
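This means that when your model uses only MonthlyCharges it works fine, because that column has no nulls. But when you include TotalCharges and some of its nulls end up in the training set, it throws the error above.
Try using .fillna(0) to fill the nulls with zero: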
relevant = df_num.select(['MonthlyCharges', 'TotalCharges', 'Churn_indexed']).fillna(0)
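Comments:
Hi AdibP, it looks like you are right, but I don't know why. I loaded the data into a dataframe called df, but then ran df = df.na.drop(), which should have gotten rid of those values, shouldn't it? df_ind then converts the categorical data to numerical values, and df_num simply isolates all the numerical data into one dataframe. Even after running the .fillna(0) line you provided, the relevant dataframe still reports 11 NaN/null values in the TotalCharges column, but training did finally complete, so I guess this is solved. Thanks for your help!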