如何在 Spark 中找到分组数据的准确中位数

Posted

技术标签:

【中文标题】如何在 Spark 中找到分组数据的准确中位数【英文标题】:How to find exact median for grouped data in Spark 【发布时间】:2017-01-02 17:56:02 【问题描述】:

我需要使用 Scala 计算 Spark 中 Double 数据类型的分组数据集的精确中位数。

与类似查询不同:Find median in spark SQL for multiple double datatype columns。这个问题是关于分组数据的查找数据,而另一个问题是关于在 RDD 级别上查找中位数。

这是我的示例数据

scala> sqlContext.sql("select * from test").show()

+---+---+
| id|num|
+---+---+
|  A|0.0|
|  A|1.0|
|  A|1.0|
|  A|1.0|
|  A|0.0|
|  A|1.0|
|  B|0.0|
|  B|1.0|
|  B|1.0|
+---+---+

预期答案:

+--------+
| Median |
+--------+
|   1    |
|   1    |
+--------+

我尝试了以下选项,但没有运气:

1) Hive 函数百分位数,它仅适用于 BigInt。

2) Hive 函数 percentile_approx,但它没有按预期工作(返回 0.25 vs 1)。

scala> sqlContext.sql("select percentile_approx(num, 0.5) from test group by id").show()

+----+
| _c0|
+----+
|0.25|
|0.25|
+----+

【问题讨论】:

【参考方案1】:

最简单的方法(需要 Spark 2.0.1+ 而不是精确的中位数)

正如 cmets 中提到的第一个问题 Find median in Spark SQL for double datatype columns 中所述,我们可以使用 percentile_approx 来计算 Spark 2.0.1+ 的中值。要将其应用于 Apache Spark 中的分组数据,查询应如下所示:

val df = Seq(("A", 0.0), ("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("B", 0.0), ("B", 1.0), ("B", 1.0)).toDF("id", "num")
df.createOrReplaceTempView("df")
spark.sql("select id, percentile_approx(num, 0.5) as median from df group by id order by id").show()

输出为:

+---+------+
| id|median|
+---+------+
|  A|   1.0|
|  B|   1.0|
+---+------+

也就是说,这是一个近似值(而不是每个问题的精确中位数)。

计算分组数据的准确中位数

有多种方法,所以我相信 SO 中的其他人可以提供更好或更有效的示例。但这里有一段代码 sn-p 计算 Spark 中分组数据的中位数(在 Spark 1.6 和 Spark 2.1 中验证):

import org.apache.spark.SparkContext._

val rdd: RDD[(String, Double)] = sc.parallelize(Seq(("A", 1.0), ("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 0.0), ("A", 1.0), ("B", 0.0), ("B", 1.0), ("B", 1.0)))

// Scala median function
def median(inputList: List[Double]): Double = 
  val count = inputList.size
  if (count % 2 == 0) 
    val l = count / 2 - 1
    val r = l + 1
    (inputList(l) + inputList(r)).toDouble / 2
   else
    inputList(count / 2).toDouble


// Sort the values
val setRDD = rdd.groupByKey()
val sortedListRDD = setRDD.mapValues(_.toList.sorted)

// Output DataFrame of id and median
sortedListRDD.map(m => 
  (m._1, median(m._2))
).toDF("id", "median_of_num").show()

输出为:

+---+-------------+
| id|median_of_num|
+---+-------------+
|  A|          1.0|
|  B|          1.0|
+---+-------------+

我应该指出一些警告,因为这可能不是最有效的实现:

目前使用的groupByKey 性能不是很好。您可能希望将其更改为 reduceByKey(更多信息请访问 Avoid GroupByKey) 使用 Scala 函数计算 median

这种方法应该适用于少量数据,但如果每个键都有数百万行,建议使用 Spark 2.0.1+ 并使用 percentile_approx 方法。

【讨论】:

并不是我认为的中位数在分布式环境中要复杂得多。因此 Spark 2.0 中使用 percentile_approx 的原因。 我们可以避免创建临时视图吗?【参考方案2】:

这是我在 SPARK 中的 PERCENTILE_COUNT 函数版本。这可用于查找 Dataframe 中分组数据的中值。希望它可以帮助某人。随时提供您的建议以改进解决方案。

val PERCENTILEFLOOR = udf((maxrank: Integer, percentile: Double) => scala.math.floor(1 + (percentile * (maxrank - 1))))

  val PERCENTILECEIL = udf((maxrank: Integer, percentile: Double) => scala.math.ceil(1 + (percentile * (maxrank - 1))))

  val PERCENTILECALC = udf((maxrank: Integer, percentile: Double, floorVal: Double, ceilVal: Double, floorNum: Double, ceilNum: Double)
=> 
    if (ceilNum == floorNum) 
      floorVal
     else 
      val RN = (1 + (percentile * (maxrank - 1)))
      ((ceilNum - RN) * floorVal) + ((RN - floorNum) * ceilVal)
       )



/**    * The result of PERCENTILE_CONT is computed by linear interpolation between values after ordering them.    * Using the percentile value (P) and the number of rows (N) in the aggregation group,    * we compute the row number we are interested in after ordering the rows with respect to the sort specification.    * This row number (RN) is computed according to the formula RN = (1+ (P*(N-1)).    * The final result of the aggregate function is computed by linear interpolation between the values from rows at row numbers   
* CRN = CEILING(RN) and FRN = FLOOR(RN).    *    * The final result will be:    *    * If (CRN = FRN = RN) then the result is    * (value of expression from row at RN)    * Otherwise the result is    * (CRN - RN) * (value of expression for row at FRN) +    * (RN - FRN) * (value of expression for row at CRN)    *    * Parameter details    *    * @inputDF - Dataframe for computation    * @medianCol - Column for which percentile to be calculated    * @grouplist - Group list for dataframe before sorting    * @percentile - numeric value between 0 and 1 to express the percentile to be calculated    *    */

  def percentile_count(inputDF: DataFrame, medianCol: String, groupList: List[String], percentile: Double): DataFrame = 
    val orderList = List(medianCol)

    val wSpec3 = Window.partitionBy(groupList.head, groupList.tail: _*).orderBy(orderList.head, orderList.tail: _*)
    //   Group, sort and rank the DF
    val rankedDF = inputDF.withColumn("rank", row_number().over(wSpec3))

    // Find the maximum for each group 
    val groupedMaxDF = rankedDF.groupBy(groupList.head, groupList.tail: _*).agg(max("rank").as("maxval"))

    // CRN calculation
    val ceilNumDF = groupedMaxDF.withColumn("rank", PERCENTILECEIL(groupedMaxDF("maxval"), lit(percentile))).drop("maxval")

    // FRN calculation
    val floorNumDF = groupedMaxDF.withColumn("rank", PERCENTILEFLOOR(groupedMaxDF("maxval"), lit(percentile)))

    val ntileGroup = "rank" :: groupList

    //Get the values for the CRN and FRN 
    val floorDF = floorNumDF.join(rankedDF, ntileGroup).withColumnRenamed("rank", "floorNum").withColumnRenamed(medianCol, "floorVal")
    val ceilDF = ceilNumDF.join(rankedDF, ntileGroup).withColumnRenamed("rank", "ceilNum").withColumnRenamed(medianCol, "ceilVal")

    //Get both the values for CRN and FRN in same row
    val resultDF = floorDF.join(ceilDF, groupList)

    val finalList = "median_" + medianCol :: groupList

    // Calculate the median using the UDF PERCENTILECALC and returns the DF
    resultDF.withColumn("median_" + medianCol, PERCENTILECALC(resultDF("maxval"), lit(percentile), resultDF("floorVal"), resultDF("ceilVal"), resultDF("floorNum"), resultDF("ceilNum"))).select(finalList.head, finalList.tail: _*)

  

【讨论】:

【参考方案3】:

您可以尝试此解决方案以获得精确的中位数。我在这里描述了 spark sql 解决方案gist.github。 为了计算准确的中位数,我将 row_number() 和 count() 函数与窗口函数结合使用。

val data1 = Array( ("a", 0), ("a", 1), ("a", 1), ("a", 1), ("a", 0), ("a", 1))
val data2 = Array( ("b", 0), ("b", 1), ("b", 1))
val union = data1.union(data2)
val df = sc.parallelize(union).toDF("key", "val")
df.cache.createOrReplaceTempView("kvTable")
spark.sql("SET spark.sql.shuffle.partitions=2")


var ds = spark.sql("""
   SELECT key, avg(val) as median
FROM ( SELECT key, val, rN, (CASE WHEN cN % 2 = 0 then (cN DIV 2) ELSE (cN DIV 2) + 1 end) as m1, (cN DIV 2) + 1 as m2 
        FROM ( 
            SELECT key, val, row_number() OVER (PARTITION BY key ORDER BY val ) as rN, count(val) OVER (PARTITION BY key ) as cN
            FROM kvTable
         ) s
    ) r
WHERE rN BETWEEN m1 and m2
GROUP BY key 
""")

Spark 有效地执行和优化此查询,因为它重用了数据分区。

scala> ds.show
+---+------+
|key|median|
+---+------+
|  a|   1.0|
|  b|   1.0|
+---+------+

【讨论】:

【参考方案4】:

在 Spark 2.4 中添加了高阶函数 element_at。我们可以使用 Window 函数,或者 groupBy 然后加入。

样本数据

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

case class Salary(depName: String, empNo: Long, salary: Long)
val empsalary = Seq(
  Salary("sales", 1, 5000),
  Salary("personnel", 2, 3900),
  Salary("sales", 3, 4800),
  Salary("sales", 4, 4800),
  Salary("personnel", 5, 3500),
  Salary("develop", 7, 4200),
  Salary("develop", 8, 6000),
  Salary("develop", 9, 4500),
  Salary("develop", 10, 5200),
  Salary("develop", 11, 5200)).toDS

带窗口功能

val byDepName = Window.partitionBy('depName).orderBy('salary)
val df = empsalary.withColumn(
  "salaries", collect_list('salary) over byDepName).withColumn(
  "median_salary", element_at('salaries, (size('salaries)/2 + 1).cast("int")))

df.show(false)

与 groupBy 然后加入

val dfMedian = empsalary.groupBy("depName").agg(
  sort_array(collect_list('salary)).as("salaries")).select(
  'depName, 
  element_at('salaries, (size('salaries)/2 + 1).cast("int")).as("median_salary"))
empsalary.join(dfMedian, "depName").show(false)

【讨论】:

【参考方案5】:

如果您不想使用 spark-sql(就像我一样),您可以使用 cume_dist 函数。

请看下面的例子:

import org.apache.spark.sql.functions => F
import org.apache.spark.sql.expressions.Window
val df = (1 to 10).toSeq.toDF
val win = Window.
    partitionBy(F.col("value")).
    orderBy(F.col("value")).
    rangeBetween(Window.unboundedPreceding, Window.currentRow)
df.withColumn("c", F.cume_dist().over(win)).show

结果:

+-----+---+
|value|  c|
+-----+---+
|    1|0.1|
|    2|0.2|
|    3|0.3|
|    4|0.4|
|    5|0.5|
|    6|0.6|
|    7|0.7|
|    8|0.8|
|    9|0.9|
|   10|1.0|
+-----+---+

中位数是 df("c") 等于 0.5 的值。 我希望它会有所帮助,Elior。

【讨论】:

相同的代码会导致不同的输出,所有 c 值都是 1.0。 嗨@ErkanŞirin,不知道你的意思。窗口规范:rangeBetween(Window.unboundedPreceding, Window.currentRow) 应该注意列 c 上的运行总和 不,我只是说相同的代码会导致不同的输出。在我的情况下 c cloumn 不是 0.1, 0.2 ...but 1.0, 1.0,1.0, 1.0【参考方案6】:

只是添加到 Elior 的回答和回应 Erkan,每列的输出为 1.0 的原因是 partitionBy(F.col("value")) 将数据分区为每个分区的一行,这样当window 计算cume_dist 它对单个值进行计算,结果为 1.0。

从窗口操作中删除 partitionBy(F.col("value")) 会产生预期的分位数。


Elior 的回答开始


如果您不想使用 spark-sql(像我一样),可以使用 cume_dist 函数。 请参见下面的示例:

import org.apache.spark.sql.functions => F
import org.apache.spark.sql.expressions.Window
val df = (1 to 10).toSeq.toDF
val win = Window.
    partitionBy(F.col("value")).    //Remove this line
    orderBy(F.col("value")).
    rangeBetween(Window.unboundedPreceding, Window.currentRow)
df.withColumn("c", F.cume_dist().over(win)).show

结果:

+-----+---+
|value|  c|
+-----+---+
|    1|0.1|
|    2|0.2|
|    3|0.3|
|    4|0.4|
|    5|0.5|
|    6|0.6|
|    7|0.7|
|    8|0.8|
|    9|0.9|
|   10|1.0|
+-----+---+

中位数是 df("c") 等于 0.5 的值。我希望它会有所帮助,Elior。


Elior 的回答结束


没有 partitionBy 定义的窗口:

val win = Window.
    orderBy(F.col("value")).
    rangeBetween(Window.unboundedPreceding, Window.currentRow)
df.withColumn("c", F.cume_dist().over(win)).show

【讨论】:

以上是关于如何在 Spark 中找到分组数据的准确中位数的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Spark SQL 中找到分组向量列的平均值?

如何使用 Spark 查找中位数和分位数

如何从用户的输入中找到平均值和中位数? [复制]

如何在 java 中使用 Apache spark 计算中位数和众数?

如何在 Spark 数据框中进行分组和聚合后过滤?

您将如何限制 Spark 中每个分组键要处理的记录数? (对于倾斜的数据)