Pyspark SQL/SQL 中的窗口和聚合函数

Posted

技术标签:

【中文标题】Pyspark SQL/SQL 中的窗口和聚合函数【英文标题】:Window & Aggregate functions in Pyspark SQL/SQL 【发布时间】:2021-10-16 23:00:52 【问题描述】:

@Vaebhav 回答后意识到问题设置不正确。 因此用他的代码 sn-p 编辑它。

我有下表

from pyspark.sql.types import IntegerType,TimestampType,DoubleType

input_str = """
4219,2018-01-01 08:10:00,3.0,50.78,
4216,2018-01-02 08:01:00,5.0,100.84,
4217,2018-01-02 20:00:00,4.0,800.49,
4139,2018-01-03 11:05:00,1.0,400.0,
4170,2018-01-03 09:10:00,2.0,100.0,
4029,2018-01-06 09:06:00,6.0,300.55,
4029,2018-01-06 09:16:00,2.0,310.55,
4217,2018-01-06 09:36:00,5.0,307.55,
1139,2018-01-21 11:05:00,1.0,400.0,
2170,2018-01-21 09:10:00,2.0,100.0,
4218,2018-02-06 09:36:00,5.0,307.55,
4218,2018-02-06 09:36:00,5.0,307.55
""".split(",")

input_values = list(map(lambda x: x.strip() if x.strip() != '' else None, input_str))
cols = list(map(lambda x: x.strip() if x.strip() != 'null' else None, "customer_id,timestamp,quantity,price".split(',')))
n = len(input_values)
n_cols = 4
input_list = [tuple(input_values[i:i+n_cols]) for i in range(0,n,n_cols)]
sparkDF = sqlContext.createDataFrame(input_list,cols)
sparkDF = sparkDF.withColumn('customer_id',F.col('customer_id').cast(IntegerType()))\
                 .withColumn('timestamp',F.col('timestamp').cast(TimestampType()))\
                 .withColumn('quantity',F.col('quantity').cast(IntegerType()))\
                 .withColumn('price',F.col('price').cast(DoubleType()))

我想按如下方式计算聚合门:

trxn_date unique_cust_visits next_7_day_visits next_30_day_visits
2018-01-01 1 7 9
2018-01-02 2 6 8
2018-01-03 2 4 6
2018-01-06 2 2 4
2018-01-21 2 2 3
2018-02-06 1 1 1

在哪里

trxn_date 是时间戳列中的日期, daily_cust_visits 是唯一的客户计数, next_7_day_visits 是基于 7 天滚动窗口的客户计数。 next_30_day_visits 是基于 30 天滚动窗口的客户计数。

我想将代码编写为单个 SQL 查询。

【问题讨论】:

【参考方案1】:

您可以通过使用ROW 而不是RANGE Frame Type 来实现这一点,可以在here 找到一个很好的解释

ROW - 基于与当前输入行位置的物理偏移量

RANGE - 基于与当前输入行位置的逻辑偏移量

此外,在您的实现中,PARTITION BY 子句将是多余的,因为它不会为前瞻创建所需的 Frames

数据准备

input_str = """
4219,2018-01-02 08:10:00,3.0,50.78,
4216,2018-01-02 08:01:00,5.0,100.84,
4217,2018-01-02 20:00:00,4.0,800.49,
4139,2018-01-03 11:05:00,1.0,400.0,
4170,2018-01-03 09:10:00,2.0,100.0,
4029,2018-01-06 09:06:00,6.0,300.55,
4029,2018-01-06 09:16:00,2.0,310.55,
4217,2018-01-06 09:36:00,5.0,307.55
""".split(",")

input_values = list(map(lambda x: x.strip() if x.strip() != '' else None, input_str))

cols = list(map(lambda x: x.strip() if x.strip() != 'null' else None, "customer_id  timestamp   quantity    price".split('\t')))
        
n = len(input_values)
n_cols = 4

input_list = [tuple(input_values[i:i+n_cols]) for i in range(0,n,n_cols)]

sparkDF = sql.createDataFrame(input_list,cols)

sparkDF = sparkDF.withColumn('customer_id',F.col('customer_id').cast(IntegerType()))\
                 .withColumn('timestamp',F.col('timestamp').cast(TimestampType()))\
                 .withColumn('quantity',F.col('quantity').cast(IntegerType()))\
                 .withColumn('price',F.col('price').cast(DoubleType()))

sparkDF.show()

+-----------+-------------------+--------+------+
|customer_id|          timestamp|quantity| price|
+-----------+-------------------+--------+------+
|       4219|2018-01-02 08:10:00|       3| 50.78|
|       4216|2018-01-02 08:01:00|       5|100.84|
|       4217|2018-01-02 20:00:00|       4|800.49|
|       4139|2018-01-03 11:05:00|       1| 400.0|
|       4170|2018-01-03 09:10:00|       2| 100.0|
|       4029|2018-01-06 09:06:00|       6|300.55|
|       4029|2018-01-06 09:16:00|       2|310.55|
|       4217|2018-01-06 09:36:00|       5|307.55|
+-----------+-------------------+--------+------+

窗口聚合

sparkDF.createOrReplaceTempView("transactions")

sql.sql("""
        SELECT 
            TO_DATE(timestamp) as trxn_date
            ,COUNT(DISTINCT customer_id) as unique_cust_visits
            ,SUM(COUNT(DISTINCT customer_id)) OVER (
                        ORDER BY 'timestamp'
                        ROWS BETWEEN CURRENT ROW AND 7 FOLLOWING
            ) as next_7_day_visits
        FROM transactions
        GROUP BY 1
""").show()

+----------+------------------+-----------------+
| trxn_date|unique_cust_visits|next_7_day_visits|
+----------+------------------+-----------------+
|2018-01-02|                 3|                7|
|2018-01-03|                 2|                4|
|2018-01-06|                 2|                2|
+----------+------------------+-----------------+

【讨论】:

ROW 范围将导致连续行的聚合,但是由于我正在处理事务数据,因此 RANGE 更适合。我已经编辑了我的问题以提供一个更好的例子。但是,您的回答确实引导我找到正确的解决方案:) 我看到你的回答,很高兴它现在可以工作,如果它对你有帮助,请投票并接受答案【参考方案2】:

基于@Vaebhav 的回答,在这种情况下所需的查询是

sqlContext.sql("""
        SELECT 
            TO_DATE(timestamp) as trxn_date
            ,COUNT(DISTINCT customer_id) as unique_cust_visits
            ,SUM(COUNT(DISTINCT customer_id)) OVER (
                        ORDER BY CAST(TO_DATE(timestamp) AS TIMESTAMP) DESC
                        RANGE BETWEEN INTERVAL 7 DAYS PRECEDING AND CURRENT ROW
            ) as next_7_day_visits
            ,SUM(COUNT(DISTINCT customer_id)) OVER (
                        ORDER BY CAST(TO_DATE(timestamp) AS TIMESTAMP) DESC
                        RANGE BETWEEN INTERVAL 30 DAYS PRECEDING AND CURRENT ROW
            ) as next_30_day_visits
        FROM transactions
        GROUP BY 1
        ORDER by trxn_date
""").show()
trxn_date unique_cust_visits next_7_day_visits next_30_day_visits
2018-01-01 1 7 9
2018-01-02 2 6 8
2018-01-03 2 4 6
2018-01-06 2 2 4
2018-01-21 2 2 3
2018-02-06 1 1 1

【讨论】:

以上是关于Pyspark SQL/SQL 中的窗口和聚合函数的主要内容,如果未能解决你的问题,请参考以下文章

如何在 pyspark 中对需要在聚合中聚合的分组数据应用窗口函数?

SQL 移动聚合

Hive sql及窗口函数

Pyspark 将列列表转换为聚合函数

具有组间聚合结果的 Pyspark 窗口

[SQL] SQL 基础知识梳理- 聚合和排序