Running into OneHotEncoding error on my indexed dataset

Posted: 2022-01-20 14:39:01

【Problem Description】:

I am new to Spark and Scala, so I am a bit confused, but essentially I have a public COVID dataset that I am using to run a random forest ML algorithm. The problem is that I am able to balance and index my dataset, but when I run the OneHotEncoder I get the error below. I cannot understand why, and I have even searched online for a solution. Here is my code so far.

Here is a link to the dataset: https://open.toronto.ca/dataset/covid-19-cases-in-toronto/

/* Start spark */

spark-shell --master yarn

/*Import statements*/

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer}
import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.types.{IntegerType, DoubleType}
import org.apache.spark.sql.SparkSession


/* Load the dataset */

val dataset = spark.read
.format("csv")
.option("header", "true")
.load("hdfs://10.128.0.2/BigData/covid19_cases.csv")

/* filter the dataset to show only required variables */

val dataset_filter = dataset.select(col("Age Group"), col("Classification"),
col("Client Gender"), col("Outcome"), col("Currently Hospitalized"),
col("Currently in ICU"), col("Currently Intubated"), col("Ever Hospitalized"),
col("Ever in ICU"), col("Ever Intubated"))

/* renaming the features */

val dataset_rename = dataset_filter.withColumnRenamed("Age Group", "age_group")
.withColumnRenamed("Classification","classification")
.withColumnRenamed("Client Gender","gender")
.withColumnRenamed("Currently Hospitalized","hospitalized")
.withColumnRenamed("Currently in ICU","in_icu")
.withColumnRenamed("Currently Intubated","intubated")
.withColumnRenamed("Ever Hospitalized","ever_hospitalized")
.withColumnRenamed("Ever in ICU","ever_icu")
.withColumnRenamed("Ever Intubated","ever_intubated")
.withColumnRenamed("Outcome","outcome")

/* balancing the dataset */

val dataset_fatality = dataset_rename.filter(dataset_rename("outcome") === "FATAL")
val dataset_nonfatality = dataset_rename.filter(dataset_rename("outcome") === "RESOLVED")

val fatality_sample = dataset_fatality.count().toDouble / dataset_rename.count().toDouble
val nonfatality_sample = dataset_nonfatality.sample(false,fatality_sample)

val dataset_balanced = dataset_fatality.unionAll(nonfatality_sample)

/*multi indexing the dataset*/

val inputs = Array("age_group","classification","gender",
"hospitalized","in_icu","intubated","ever_hospitalized","ever_icu",
"ever_intubated")

val outputs = Array("age_indexed","classification_indexed","gender_indexed",
"hospitalized_indexed","in_icu_indexed","intubated_indexed","ever_hospitalized_indexed",
"ever_icu_indexed","ever_intubated_indexed")

val indexer = new StringIndexer()
indexer.setInputCols(inputs)
indexer.setOutputCols(outputs)

val StrIndexer = new StringIndexer()
.setInputCol("outcome")
.setOutputCol("outcome_indexed")

val df_indexed = indexer.fit(dataset_balanced).transform(dataset_balanced)
val df_indexed2 = StrIndexer.fit(df_indexed).transform(df_indexed)

val dataset_rank = df_indexed2.select(col("outcome_indexed").cast(DoubleType),
col("age_indexed").cast(DoubleType),
col("classification_indexed").cast(DoubleType),
col("gender_indexed").cast(DoubleType),
col("hospitalized_indexed").cast(DoubleType),
col("in_icu_indexed").cast(DoubleType),
col("intubated_indexed").cast(DoubleType),
col("ever_hospitalized_indexed").cast(DoubleType),
col("ever_icu_indexed").cast(DoubleType),
col("ever_intubated_indexed").cast(DoubleType))

/* One hot encoding after indexing the dataset */

val encoder = new OneHotEncoder()
.setInputCols(Array("age_indexed", "classification_indexed", "gender_indexed", "hospitalized_indexed",
"in_icu_indexed", "intubated_indexed", "ever_hospitalized_indexed", "ever_icu_indexed", "ever_intubated_indexed"))
.setOutputCols(Array("age_vec", "class_vec", "gender_vec", "hospitalized_vec", "in_icu_vec", "intubated_vec",
"ever_hospitalized_vec", "ever_icu_vec", "ever_intubated_vec"))

val dataset_encoder = encoder.fit(dataset_rank).transform(dataset_rank)

This is the error message I get at the last step:

21/12/17 17:39:55 WARN org.apache.spark.scheduler.TaskSetManager: Lost task 2.0 in stage 28.0 (TID 56) (spark-m.us-central1-a.c.sparkclust-335416.internal executor 3): org.apache.spark.SparkException: Failed to execute user defined function(StringIndexerModel$$Lambda$4600/1280945972: (string) => double)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
        at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
        at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:755)
        at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
        at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
        at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
        at scala.collection.Iterator.foreach(Iterator.scala:943)
        at scala.collection.Iterator.foreach$(Iterator.scala:943)
        at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
        at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
        at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
        at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
        at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
        at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
        at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
        at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$3(RDD.scala:1230)
        at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$5(RDD.scala:1231)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:863)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:863)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
        at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
        at org.apache.spark.scheduler.Task.run(Task.scala:131)
        at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
        at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: StringIndexer encountered NULL value. To handle or skip NULLS, try setting StringIndexer.handleInvalid.
        at org.apache.spark.ml.feature.StringIndexerModel.$anonfun$getIndexer$1(StringIndexer.scala:396)
        at org.apache.spark.ml.feature.StringIndexerModel.$anonfun$getIndexer$1$adapted(StringIndexer.scala:391)
        ... 30 more

【Comments】:

You have null values.

【Answer 1】:

You need to add

setHandleInvalid("skip") or "keep"

in the appropriate places.
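For example, a minimal sketch that reuses the inputs/outputs arrays and the two indexers from the question ("keep" maps NULL and unseen labels to an extra index, while "skip" drops those rows during fit/transform):

/* Same indexers as in the question, with handleInvalid set so that
   fit/transform no longer throws on NULL values */

val indexer = new StringIndexer()
  .setInputCols(inputs)
  .setOutputCols(outputs)
  .setHandleInvalid("keep") // or "skip" to drop rows containing NULLs

val StrIndexer = new StringIndexer()
  .setInputCol("outcome")
  .setOutputCol("outcome_indexed")
  .setHandleInvalid("keep")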

Or make sure that no null values are present in the relevant input data.
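If you go that route, here is a quick sketch for inspecting and dropping NULLs before indexing (this assumes the dataset_balanced DataFrame and the inputs array from the question, with org.apache.spark.sql.functions._ already imported):

/* Count NULLs per input column to see which ones are affected;
   count(when(col.isNull, c)) counts only the rows where the column is NULL */

dataset_balanced.select(
  inputs.map(c => count(when(col(c).isNull, c)).alias(c)): _*
).show()

/* Drop every row that has a NULL in an indexed column or the label */

val dataset_clean = dataset_balanced.na.drop(inputs :+ "outcome")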

https://weishungchung.com/2017/08/14/stringindexer-transform-fails-when-column-contains-nulls/ is a good example of this.

【Discussion】:
