向 DataFrame 添加一列,其值为 1,其中预测大于自定义阈值
Posted
技术标签:
【中文标题】向 DataFrame 添加一列,其值为 1,其中预测大于自定义阈值【英文标题】:Add a column to DataFrame with value of 1 where prediction greater than a custom threshold 【发布时间】:2017-05-12 02:34:14 【问题描述】:我正在尝试向DataFrame
添加一列,当输出类概率很高时,该列的值应为 1。像这样的:
val output = predictions
.withColumn(
"easy",
when( $"label" === $"prediction" &&
$"probability" > 0.95, 1).otherwise(0)
)
问题是,probability
是Vector
,而0.95
是Double
,所以上面的行不通。我真正需要的更像是max($"probability") > 0.95
,但当然这也行不通。
实现此目的的正确方法是什么?
【问题讨论】:
【参考方案1】:这是一个实现您的问题的简单示例。 创建一个 udf 并通过概率列,并为新添加的列返回 0 或 1。在一行中使用 WrappedArray 代替 Array、Vector。
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
val data = spark.sparkContext.parallelize(Seq(
(Vector(0.78, 0.98, 0.97), 1), (Vector(0.78, 0.96), 2), (Vector(0.78, 0.50), 3)
)).toDF("probability", "id")
data.withColumn("label", label($"probability")).show()
def label = udf((prob: mutable.WrappedArray[Double]) =>
if (prob.max >= 0.95) 1 else 0
)
输出:
+------------------+---+-----+
| probability| id|label|
+------------------+---+-----+
|[0.78, 0.98, 0.97]| 1| 1|
| [0.78, 0.96]| 2| 1|
| [0.78, 0.5]| 3| 0|
+------------------+---+-----+
【讨论】:
【参考方案2】:定义UDF
val findP = udf((label: <type>, prediction: <type>, probability: <type> ) =>
if (label == prediction && vector.toArray.max > 0.95) 1 else 0
)
在 withCoulmn() 中使用 UDF
val output = predictions.withColumn("easy",findP($"lable",$"prediction",$"probability"))
【讨论】:
这行得通。谢谢你!我要补充的一件事是,为了让它发挥作用,我必须找到正确的概率类型。这是一个 DenseVector。看到这个问题:***.com/questions/35855382/…【参考方案3】:使用 udf,例如:
val func = (label: String, prediction: String, vector: Vector) =>
if(label == prediction && vector.toArray.max > 0.95) 1 else 0
val output = predictions
.select($"label", func($"label", $"prediction", $"probability").as("easy"))
【讨论】:
以上是关于向 DataFrame 添加一列,其值为 1,其中预测大于自定义阈值的主要内容,如果未能解决你的问题,请参考以下文章
PySpark 从 TimeStampType 列向 DataFrame 添加一列
在Spark Dataframe中的列列表中添加一列rowums