Pyspark 滚动平均值从第一行开始

Posted

技术标签:

【中文标题】Pyspark 滚动平均值从第一行开始【英文标题】:Pyspark Rolling Average starting at first row 【发布时间】:2020-08-16 17:25:35 【问题描述】:

我正在尝试计算 Pyspark 中的滚动平均值。我让它工作,但它的行为似乎与我预期的不同。滚动平均值从第一行开始。

例如:

columns = ['month', 'day', 'value']
data = [('JAN', '01', '20000'), ('JAN', '02', '40000'), ('JAN', '03', '30000'), ('JAN', '04', '25000'), ('JAN', '05', '5000'), ('JAN', '06', '15000'),
       ('FEB', '01', '10000'), ('FEB', '02', '50000'), ('FEB', '03', '100000'), ('FEB', '04', '60000'), ('FEB', '05', '1000'), ('FEB', '06', '10000'),]

df_test = sc.createDataFrame(data).toDF(*columns)
win = Window.partitionBy('month').orderBy('day').rowsBetween(-2,0)
df_test.withColumn('rolling_average', f.avg('value').over(win)).show()

+-----+---+------+------------------+
|month|day| value|   rolling_average|
+-----+---+------+------------------+
|  JAN| 01| 20000|           20000.0|
|  JAN| 02| 40000|           30000.0|
|  JAN| 03| 30000|           30000.0|
|  JAN| 04| 25000|31666.666666666668|
|  JAN| 05|  5000|           20000.0|
|  JAN| 06| 15000|           15000.0|
|  FEB| 01| 10000|           10000.0|
|  FEB| 02| 50000|           30000.0|
|  FEB| 03|100000|53333.333333333336|
|  FEB| 04| 60000|           70000.0|
|  FEB| 05|  1000|53666.666666666664|
|  FEB| 06| 10000|23666.666666666668|
+-----+---+------+------------------+

这将更符合我的预期。有没有办法得到这种行为?

+-----+---+------+------------------+
|month|day| value|   rolling_average|
+-----+---+------+------------------+
|  JAN| 01| 20000|              null|
|  JAN| 02| 40000|              null|
|  JAN| 03| 30000|           30000.0|
|  JAN| 04| 25000|31666.666666666668|
|  JAN| 05|  5000|           20000.0|
|  JAN| 06| 15000|           15000.0|
|  FEB| 01| 10000|              null|
|  FEB| 02| 50000|              null|
|  FEB| 03|100000|53333.333333333336|
|  FEB| 04| 60000|           70000.0|
|  FEB| 05|  1000|53666.666666666664|
|  FEB| 06| 10000|23666.666666666668|
+-----+---+------+------------------+

默认行为的问题是我需要另一列来跟踪延迟应该从哪里开始。

【问题讨论】:

【参考方案1】:

尝试使用 row_number() 窗口函数,然后使用 when+otherwise 语句替换 null。

要更改 lag start,然后更改 when 语句 col("rn") <= <value> 的值。

Example:

columns = ['month', 'day', 'value']
data = [('JAN', '01', '20000'), ('JAN', '02', '40000'), ('JAN', '03', '30000'), ('JAN', '04', '25000'), ('JAN', '05', '5000'), ('JAN', '06', '15000'),
       ('FEB', '01', '10000'), ('FEB', '02', '50000'), ('FEB', '03', '100000'), ('FEB', '04', '60000'), ('FEB', '05', '1000'), ('FEB', '06', '10000'),]

df_test = sc.createDataFrame(data).toDF(*columns)
win = Window.partitionBy('month').orderBy('day').rowsBetween(-2,0)

win1 = Window.partitionBy('month').orderBy('day')

df_test.withColumn('rolling_average', f.avg('value').over(win)).\
withColumn("rn",row_number().over(win1)).\
withColumn("rolling_average",when(col("rn") <= 2 ,lit(None)).\
otherwise(col("rolling_average"))).\
drop("rn").\
show()
#+-----+---+------+------------------+
#|month|day| value|   rolling_average|
#+-----+---+------+------------------+
#|  FEB| 01| 10000|              null|
#|  FEB| 02| 50000|              null|
#|  FEB| 03|100000|53333.333333333336|
#|  FEB| 04| 60000|           70000.0|
#|  FEB| 05|  1000|53666.666666666664|
#|  FEB| 06| 10000|23666.666666666668|
#|  JAN| 01| 20000|              null|
#|  JAN| 02| 40000|              null|
#|  JAN| 03| 30000|           30000.0|
#|  JAN| 04| 25000|31666.666666666668|
#|  JAN| 05|  5000|           20000.0|
#|  JAN| 06| 15000|           15000.0|
#+-----+---+------+------------------+

【讨论】:

谢谢。这比我对 row_number() 的想法更有效,而且更优雅。【参考方案2】:

@484 的更多简化版本。

import pyspark.sql.functions as f
from pyspark.sql import Window

w1 = Window.partitionBy('month').orderBy('day')
w2 = Window.partitionBy('month').orderBy('day').rowsBetween(-2, 0)

df.withColumn("rolling_average", f.when(f.row_number().over(w1) > f.lit(2), f.avg('value').over(w2))).show(10, False)

附言请不要将此标记为答案:)

【讨论】:

以上是关于Pyspark 滚动平均值从第一行开始的主要内容,如果未能解决你的问题,请参考以下文章

使用 PySpark 而不使用窗口对来自 Kafka 的流数据执行滚动平均

pySpark - 在滚动窗口中获取最大值行

pyspark 时间序列数据的高性能滚动/窗口聚合

如何在 PySpark 中找到数组数组的平均值

在pyspark中用平均值填充缺失值

Pyspark:添加具有 groupby 平均值的列