没有 UDF 的 Spark 数据集的加权平均值

Posted

技术标签:

【中文标题】没有 UDF 的 Spark 数据集的加权平均值【英文标题】:Weighted average with Spark Datasets without UDF 【发布时间】:2017-08-10 19:23:28 【问题描述】:

虽然有人已经问过关于计算 Weighted Average in Spark 的问题,但在这个问题中,我问的是使用 Datasets/DataFrames 而不是 RDD。

如何在 Spark 中计算加权平均值?我有两列:计数和以前的平均值:

case class Stat(name:String, count: Int, average: Double)
val statset = spark.createDataset(Seq(Stat("NY", 1,5.0),
                           Stat("NY",2,1.5),
                           Stat("LA",12,1.0),
                           Stat("LA",15,3.0)))

我希望能够像这样计算加权平均值:

display(statset.groupBy($"name").agg(sum($"count").as("count"),
                    weightedAverage($"count",$"average").as("average")))

可以使用 UDF 接近:

val weightedAverage = udf(
  (row:Row)=>
    val counts = row.getAs[WrappedArray[Int]](0)
    val averages = row.getAs[WrappedArray[Double]](1)
    val (count,total) = (counts zip averages).foldLeft((0,0.0))
      case((cumcount:Int,cumtotal:Double),(newcount:Int,newaverage:Double))=>(cumcount+newcount,cumtotal+newcount*newaverage)
    (total/count)  // Tested by returning count here and then extracting. Got same result as sum.
  
)

display(statset.groupBy($"name").agg(sum($"count").as("count"),
                    weightedAverage(struct(collect_list($"count"),
                                    collect_list($"average"))).as("average")))

(感谢Passing a list of tuples as a parameter to a spark udf in scala 提供的帮助,帮助编写本文)

新手:使用这些导入:

import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.collection.mutable.WrappedArray

有没有办法使用内置列函数而不是 UDF 来完成此任务? UDF 感觉很笨重,如果数字变大,您必须将 Int 转换为 Long。

【问题讨论】:

【参考方案1】:

看起来你可以分两次完成:

val totalCount = statset.select(sum($"count")).collect.head.getLong(0)

statset.select(lit(totalCount) as "count", sum($"average" * $"count" / lit(totalCount)) as "average").show

或者,包括你刚刚添加的 groupBy:

display(statset.groupBy($"name").agg(sum($"count").as("count"),
                    sum($"count"*$"average").as("total"))
               .select($"name",$"count",($"total"/$"count")))

【讨论】:

我会将总数添加为第二个聚合中的另一列,然后在最后进行除法。第二遍需要检查的数据要少得多。 @MichelLemay:谢谢!这正是我需要慢慢思考的东西。我建议对您的答案进行修改,该修改也适用于 groupBy。 如果对您有帮助,您可以接受@JosiahYoder 的答案

以上是关于没有 UDF 的 Spark 数据集的加权平均值的主要内容,如果未能解决你的问题,请参考以下文章

统计 Spark 中 UDF 的调用次数

group的加权平均值不等于pandas groupby中的总平均值

C++ 以费波纳茨数列为权重的加权均值计算方法 wMA

“平均值”是啥意思?

使用熊猫/数据框计算加权平均值

如何平滑和绘制 x 与 y 的加权平均值,由 x 加权?