如何计算 DataFrame 中的移动中位数?

Posted

技术标签:

【中文标题】如何计算 DataFrame 中的移动中位数?【英文标题】:How to calculate moving median in DataFrame? 【发布时间】:2017-05-19 04:36:49 【问题描述】:

有没有办法为 Spark DataFrame 中的属性计算 移动中位数

我希望可以使用窗口函数计算移动中位数(通过使用rowsBetween(0,10) 定义窗口),但没有计算它的功能(类似于averagemean)。

【问题讨论】:

【参考方案1】:

这是我扩展 UserDefinedAggregateFunction 以获得移动中位数的类。

class MyMedian extends org.apache.spark.sql.expressions.UserDefinedAggregateFunction 
  def inputSchema: org.apache.spark.sql.types.StructType =
    org.apache.spark.sql.types.StructType(org.apache.spark.sql.types.StructField("value", org.apache.spark.sql.types.DoubleType) :: Nil)

  def bufferSchema: org.apache.spark.sql.types.StructType = org.apache.spark.sql.types.StructType(
    org.apache.spark.sql.types.StructField("window_list", org.apache.spark.sql.types.ArrayType(org.apache.spark.sql.types.DoubleType, false)) :: Nil
  )
  def dataType: org.apache.spark.sql.types.DataType = org.apache.spark.sql.types.DoubleType
  def deterministic: Boolean = true
  def initialize(buffer: org.apache.spark.sql.expressions.MutableAggregationBuffer): Unit = 
    buffer(0) = new scala.collection.mutable.ArrayBuffer[Double]()
  
  def update(buffer: org.apache.spark.sql.expressions.MutableAggregationBuffer,input: org.apache.spark.sql.Row): Unit = 
    var bufferVal=buffer.getAs[scala.collection.mutable.WrappedArray[Double]](0).toBuffer
    bufferVal+=input.getAs[Double](0)
    buffer(0) = bufferVal
  
  def merge(buffer1: org.apache.spark.sql.expressions.MutableAggregationBuffer, buffer2: org.apache.spark.sql.Row): Unit = 
    buffer1(0) = buffer1.getAs[scala.collection.mutable.ArrayBuffer[Double]](0) ++ buffer2.getAs[scala.collection.mutable.ArrayBuffer[Double]](0)
  
  def evaluate(buffer: org.apache.spark.sql.Row): Any = 
      var sortedWindow=buffer.getAs[scala.collection.mutable.WrappedArray[Double]](0).sorted.toBuffer
      var windowSize=sortedWindow.size
      if(windowSize%2==0)
          var index=windowSize/2
          (sortedWindow(index) + sortedWindow(index-1))/2
      else
          var index=(windowSize+1)/2 - 1
          sortedWindow(index)
      
  

使用上面的 UDAF 示例:

// Create an instance of UDAF MyMedian.
val mm = new MyMedian

var movingMedianDS = dataSet.withColumn("MovingMedian", mm(col("value")).over( Window.partitionBy("GroupId").rowsBetween(-10,10)) )

【讨论】:

如何注册 UDAF 并在 Spark 应用程序中使用它? @Jacek //创建 MyMedian 对象 var winMed=new MyMedian //定义窗口 val wSpec1 = Window.partitionBy….. //在窗口上应用中位数 df.withColumn(,winMed(col ()).over(wSpec1))【参考方案2】:

认为你在这里几乎没有选择。

ntile 窗口函数

我认为ntile(2)(在行窗口上)会给你两个“段”,你可以用它们来计算窗口中的中位数。

引用scaladoc:

ntile(n: Int) 窗口函数:返回有序窗口分区中的ntile组id(从1到n)。例如,如果 n 为 4,则行的第一季度将获得值 1,第二季度将获得 2,第三季度将获得 3,最后一个季度将获得 4。

这相当于 SQL 中的 NTILE 函数。

如果一组中的行数大于另一组中的行数,请从较大的组中选择最大的。

如果组中的行数是偶数,则取每组中的最大值和最小值并计算中位数。

我发现它在Calculating median using the NTILE function 中描述得非常好。

percent_rank 窗口函数

我认为percent_rank 也可能是计算行窗口中位数的一个选项。

引用scaladoc:

percent_rank() 窗口函数:返回窗口分区内行的相对排名(即百分位数)。

这是通过以下方式计算的:

(rank of row in its partition - 1) / (number of rows in the partition - 1)

这相当于 SQL 中的 PERCENT_RANK 函数。

用户定义的聚合函数 (UDAF)

您可以编写一个用户定义的聚合函数 (UDAF) 来计算窗口的中位数。

UDAF 扩展 org.apache.spark.sql.expressions.UserDefinedAggregateFunction,即(引用 scaladoc):

用于实现用户定义聚合函数 (UDAF) 的基类。

幸运的是,UserDefinedUntypedAggregation 示例中有一个自定义 UDAF 的示例实现。

【讨论】:

【参考方案3】:

Spark 2.1+ 中,我们可以使用函数 percentilepercentile_approx 来查找中位数。我们可以在聚合和窗口函数中使用它们。如您所愿,您也可以使用rowsBetween()

使用 PySpark 的示例:

from pyspark.sql import SparkSession, functions as F, Window as W
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, 10),
     (1, 20),
     (1, 30),
     (1, 40),
     (1, 50),
     (2, 50)],
    ['c1', 'c2']
)
df = (
    df
    .withColumn(
        'moving_median_1',
        F.expr('percentile(c2, 0.5)').over(W.partitionBy('c1').orderBy('c2')))
    .withColumn(
        'moving_median_2',
        F.expr('percentile(c2, 0.5) over(partition by c1 order by c2)'))
    .withColumn(
        'moving_median_3_rows_1',
        F.expr('percentile(c2, 0.5)').over(W.partitionBy('c1').orderBy('c2').rowsBetween(-2, 0)))
    .withColumn(
        'moving_median_3_rows_2',
        F.expr('percentile(c2, 0.5) over(partition by c1 order by c2 rows between 2 preceding and current row)'))
).show()
#+---+---+---------------+---------------+----------------------+----------------------+
#| c1| c2|moving_median_1|moving_median_2|moving_median_3_rows_1|moving_median_3_rows_2|
#+---+---+---------------+---------------+----------------------+----------------------+
#|  1| 10|           10.0|           10.0|                  10.0|                  10.0|
#|  1| 20|           15.0|           15.0|                  15.0|                  15.0|
#|  1| 30|           20.0|           20.0|                  20.0|                  20.0|
#|  1| 40|           25.0|           25.0|                  30.0|                  30.0|
#|  1| 50|           30.0|           30.0|                  40.0|                  40.0|
#|  2| 50|           50.0|           50.0|                  50.0|                  50.0|
#+---+---+---------------+---------------+----------------------+----------------------+

【讨论】:

以上是关于如何计算 DataFrame 中的移动中位数?的主要内容,如果未能解决你的问题,请参考以下文章

pandas使用groupby函数计算dataframe数据中每个分组的N个数值的滚动中位数值(rolling median)例如,计算某公司的多个店铺每N天(5天)的滚动销售额中位数

R语言vtreat包自动处理dataframe的缺失值使用分组的中位数来标准化数据列中每个数据的值(和中位数表连接并基于中位数进行数据标化)计算数据列的中位数或者均值并进行数据标准化

pandas使用rolling函数计算dataframe指定数据列特定窗口下的滚动中位数(rolling median)自定义指定滚动窗口的大小(window size)

R语言colSums函数rowSums函数colMeans函数rowMeans函数colMedians函数rowMedians计算dataframe行或者列的加和均值中位数实战

pandas使用to_datetime函数将字符串时间数据列转化为时间对象数据列计算dataframe结束时间列和起始时间列的时间差并计算时间差的中位数(median)

T-SQL 中的移动中位数、众数