Running into OneHotEncoding error on my indexed dataset
Posted: 2022-01-20 14:39:01

I'm new to Spark and Scala, so I'm a bit confused, but essentially I have a public COVID dataset that I'm using to run a random forest ML algorithm. The problem is that I can balance and index the dataset fine, but when I run the OneHotEncoder I get the error below. I can't figure out why, even after searching online for a solution. Here is my code so far.
Here is a link to the dataset: https://open.toronto.ca/dataset/covid-19-cases-in-toronto/
/* Start spark */
spark-shell --master yarn
/*Import statements*/
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer}
import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.types.{IntegerType, DoubleType}
import org.apache.spark.sql.SparkSession
/* Load the dataset */
val dataset = spark.read
.format("csv")
.option("header", "true")
.load("hdfs://10.128.0.2/BigData/covid19_cases.csv")
/* filter the dataset to show only required variables */
val dataset_filter = dataset.select(col("Age Group"), col("Classification"),
  col("Client Gender"), col("Outcome"), col("Currently Hospitalized"),
  col("Currently in ICU"), col("Currently Intubated"), col("Ever Hospitalized"),
  col("Ever in ICU"), col("Ever Intubated"))
/* renaming the features */
val dataset_rename = dataset_filter.withColumnRenamed("Age Group", "age_group")
.withColumnRenamed("Classification","classification")
.withColumnRenamed("Client Gender","gender")
.withColumnRenamed("Currently Hospitalized","hospitalized")
.withColumnRenamed("Currently in ICU","in_icu")
.withColumnRenamed("Currently Intubated","intubated")
.withColumnRenamed("Ever Hospitalized","ever_hospitalized")
.withColumnRenamed("Ever in ICU","ever_icu")
.withColumnRenamed("Ever Intubated","ever_intubated")
.withColumnRenamed("Outcome","outcome")
/* balancing the dataset */
val dataset_fatality = dataset_rename.filter(dataset_rename("outcome") === "FATAL")
val dataset_nonfatality = dataset_rename.filter(dataset_rename("outcome") === "RESOLVED")
val fatality_sample = dataset_fatality.count().toDouble / dataset_rename.count().toDouble
val nonfatality_sample = dataset_nonfatality.sample(false,fatality_sample)
val dataset_balanced = dataset_fatality.unionAll(nonfatality_sample)
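To confirm the downsampling produced roughly balanced classes, a quick check (a sketch, assuming the session above) is:
/* The FATAL and RESOLVED counts should now be of similar size */
dataset_balanced.groupBy("outcome").count().show()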
/* multi-indexing the dataset */
val inputs = Array("age_group","classification","gender",
"hospitalized","in_icu","intubated","ever_hospitalized","ever_icu",
"ever_intubated")
val outputs = Array("age_indexed","classification_indexed","gender_indexed",
"hospitalized_indexed","in_icu_indexed","intubated_indexed","ever_hospitalized_indexed",
"ever_icu_indexed","ever_intubated_indexed")
val indexer = new StringIndexer()
indexer.setInputCols(inputs)
indexer.setOutputCols(outputs)
val StrIndexer = new StringIndexer()
.setInputCol("outcome")
.setOutputCol("outcome_indexed")
val df_indexed = indexer.fit(dataset_balanced).transform(dataset_balanced)
val df_indexed2 = StrIndexer.fit(df_indexed).transform(df_indexed)
val dataset_rank = df_indexed2.select(col("outcome_indexed").cast(DoubleType),
col("age_indexed").cast(DoubleType),
col("classification_indexed").cast(DoubleType),
col("gender_indexed").cast(DoubleType),
col("hospitalized_indexed").cast(DoubleType),
col("in_icu_indexed").cast(DoubleType),
col("intubated_indexed").cast(DoubleType),
col("ever_hospitalized_indexed").cast(DoubleType),
col("ever_icu_indexed").cast(DoubleType),
col("ever_intubated_indexed").cast(DoubleType))
/* One hot encoding after indexing the dataset */
val encoder = new OneHotEncoder()
.setInputCols(Array("age_indexed", "classification_indexed","gender_indexed","hospitalized_indexed",
"in_icu_indexed","intubated_indexed","ever_hospitalized_indexed","ever_icu_indexed","ever_intubated_indexed"))
.setOutputCols(Array("age_vec", "class_vec", "gender_vec","hospitalized_vec","in_icu_vec","intubated_vec",
"ever_hospitalized_vec", "ever_icu_vec","ever_intubated_vec"))
val dataset_encoder = encoder.fit(dataset_rank).transform(dataset_rank)
Here is the error message I get at the last step:
21/12/17 17:39:55 WARN org.apache.spark.scheduler.TaskSetManager: Lost task 2.0 in stage 28.0 (TID 56) (spark-m.us-central1-a.c.sparkclust-335416.internal executor 3): org.apache.spark.SparkException: Failed to execute user defined function(StringIndexerModel$$Lambda$4600/1280945972: (string) => double)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:755)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator.foreach(Iterator.scala:943)
at scala.collection.Iterator.foreach$(Iterator.scala:943)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$3(RDD.scala:1230)
at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$5(RDD.scala:1231)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:863)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:863)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:131)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: StringIndexer encountered NULL value. To handle or skip NULLS, try setting StringIndexer.handleInvalid.
at org.apache.spark.ml.feature.StringIndexerModel.$anonfun$getIndexer$1(StringIndexer.scala:396)
at org.apache.spark.ml.feature.StringIndexerModel.$anonfun$getIndexer$1$adapted(StringIndexer.scala:391)
... 30 more
Comments:
You have null values.

Answer 1:

You need to add
setHandleInvalid("skip") or "keep"
in the appropriate place.
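For instance, applied to the two indexers from the question (a sketch, not tested against your data):
val indexer = new StringIndexer()
indexer.setInputCols(inputs)
indexer.setOutputCols(outputs)
indexer.setHandleInvalid("skip") // "skip" drops rows with nulls; "keep" maps them to an extra index

val StrIndexer = new StringIndexer()
  .setInputCol("outcome")
  .setOutputCol("outcome_indexed")
  .setHandleInvalid("skip")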
Alternatively, make sure there are no null values in the relevant input data.
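If you would rather drop the nulls up front instead, a minimal sketch (reusing inputs and dataset_balanced from your question) is:
/* Count the nulls per input column to see where they come from */
dataset_balanced.select(inputs.map(c => count(when(col(c).isNull, c)).alias(c)): _*).show()
/* Drop every row with a null in any column the indexers consume */
val dataset_clean = dataset_balanced.na.drop(inputs :+ "outcome")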
https://weishungchung.com/2017/08/14/stringindexer-transform-fails-when-column-contains-nulls/ is a good example.