Pyspark 窗口函数,具有对旅行者数量进行取整的条件

Posted

技术标签:

【中文标题】Pyspark 窗口函数,具有对旅行者数量进行取整的条件【英文标题】:Pyspark window function with conditions to round number of travelers 【发布时间】:2021-10-06 09:59:15 【问题描述】:

我正在使用 Pyspark,我想创建一个执行以下操作的函数:

给定描述火车用户交易的数据:

+----+----------+--------+-----+
|date|total_trav|num_trav|order|
+----+----------+--------+-----+
|   1|         9|     2.7|    1|
|   1|         9|     1.3|    2|
|   1|         9|     1.3|    3|
|   1|         9|     1.3|    4|
|   1|         9|     1.2|    5|
|   1|         9|     1.1|    6|
|   2|         9|     2.7|    1|
|   2|         9|     1.3|    2|
|   2|         9|     1.3|    3|
|   2|         9|     1.3|    4|
|   2|         9|     1.2|    5|
|   2|         9|     1.1|    6|
+----+----------+--------+-----+

我想根据order 列中给出的顺序对num_trav 列的数字进行四舍五入,同时按date 分组以获得trav_res 列。 它背后的逻辑是这样的:

我们将数据按date分组 对于每个分组数据(其中date=1date=2),我们必须始终将数字舍入到上限(ceil(num_trav))(它们的值无关紧要,始终舍入到上限)。但考虑到我们按组 (total_trav) 的旅客人数上限,在这种情况下,两组的旅客人数均为 9。 这是order 列的位置。您需要按该列给出的顺序开始四舍五入,并检查您为该组留下的旅行者数量。

例如,让我们考虑这个结果数据框,看看trav_res 列是如何形成的:

+----+----------+--------+-----+--------+
|date|total_trav|num_trav|order|trav_res|
+----+----------+--------+-----+--------+
|   1|         9|     2.7|    1|       3|
|   1|         9|     1.3|    2|       2|
|   1|         9|     1.3|    3|       2|
|   1|         9|     1.3|    4|       2|
|   1|         9|     1.2|    5|       0|
|   1|         9|     1.1|    6|       0|
|   2|         9|     2.7|    1|       3|
|   2|         9|     1.3|    2|       2|
|   2|         9|     1.3|    3|       2|
|   2|         9|     1.3|    4|       2|
|   2|         9|     1.2|    5|       0|
|   2|         9|     1.1|    6|       0|
+----+----------+--------+-----+--------+

在上面的示例中,当您按日期分组时,您将有 2 个组,最大旅行者数量为 9(total_trav 列)。 例如,对于第 1 组,你将开始将 num_trav=2.7 舍入为 3(trav_res 列),然后将 num_trav=1.3 舍入为 2,然后将 num_trav=1.3 舍入为 2,将 num_trav=1.3 舍入为 2(这是按照给定的顺序) ),然后对于下一个您没有剩余的旅行者,所以他们拥有的人数并不重要,因为没有剩余的旅行者,所以在这两种情况下他们都会得到trav_res=0

我已经尝试了一些 udf 函数,但你似乎没有完成这项工作。

【问题讨论】:

目前尚不清楚您要实现什么目标以及在第二个数据框中获取 trav_res 列的逻辑是什么。 @AnnaK。我编辑了问题并解释了示例以使其更清楚。如果还不清楚,请告诉我。谢谢 【参考方案1】:

您可以先将 F.ceil 应用到 num_trav 中的所有行,然后根据上限值创建 cumsum 列,然后在 cumsum 超过 total_trav 时将上限值设置为零,如下代码所示

# create dataframe
import pyspark.sql.functions as F
from pyspark.sql import Window

data = [(1, 9, 2.7, 1),
        (1, 9, 1.3, 2),
        (1, 9, 1.3, 3),
        (1, 9, 1.3, 4),
        (1, 9, 1.2, 5),
        (1, 9, 1.1, 6),
        (2, 9, 2.7, 1),
        (2, 9, 1.3, 2),
        (2, 9, 1.3, 3),
        (2, 9, 1.3, 4),
        (2, 9, 1.2, 5),
        (2, 9, 1.1, 6)]

df = spark.createDataFrame(data, schema=["date", "total_trav", "num_trav", "order"])

# create ceiling column
df = df.withColumn("num_trav_ceil", F.ceil("num_trav"))

# create cumulative sum column
w = Window.partitionBy("date").orderBy("order")
df = df.withColumn("num_trav_ceil_cumsum", F.sum("num_trav_ceil").over(w))

# impose 0 in trav_res when cumsum exceeds total_trav
df = (df
  .withColumn("trav_res", 
               F.when(F.col("num_trav_ceil_cumsum")<=F.col("total_trav"), 
               F.col("num_trav_ceil"))
               .otherwise(0))
  .select("date", "total_trav", "num_trav", "order", "trav_res"))

【讨论】:

几乎是正确的。唯一缺少的是trav_res 中的值的总和必须与total_trav 相同。因此,如果您总共有 6 位旅行者,则需要安排这 6 位,而不是更少。我在您的原始代码中添加了这个小修改,现在我的工作就像一个魅力。非常感谢!!! (下面发布的代码修改作为答案,再次感谢,这只是要添加的一点点,但您的答案几乎是完美的)。【参考方案2】:

该解决方案基于@AnnaK。回答,再加一点。 这样就考虑到了必须使用的旅客总数 (total_trav),而不是更多,也不是更少。

# create ceiling column
df = df_j_test_res.withColumn("num_trav_ceil", F.ceil("num_trav"))

# create cumulative sum column
w = Window.partitionBy("date").orderBy("order")
df = df.withColumn("num_trav_ceil_cumsum", F.sum("num_trav_ceil").over(w))

# impose 0 in trav_res when cumsum exceeds total_trav
df = (df
  .withColumn("trav_res", 
               F.when(F.col("num_trav_ceil_cumsum")<=F.col("total_trav"), 
               F.col("num_trav_ceil")
                     ).when((F.col('num_trav_ceil_cumsum')-F.col('total_trav')>0) & ((F.col('num_trav_ceil_cumsum')-F.col('total_trav')<=1)),
                      1)
              .otherwise(0))
  .select("date", "total_trav", "num_trav", "order", "trav_res"))

【讨论】:

以上是关于Pyspark 窗口函数,具有对旅行者数量进行取整的条件的主要内容,如果未能解决你的问题,请参考以下文章

带有窗口函数的 PySpark 数据偏度

PySpark - 获取具有相同值的数组元素的数量

PySpark 中的窗口函数和条件过滤器

如何在 pyspark 中对需要在聚合中聚合的分组数据应用窗口函数?

PySpark UDF 无法识别参数数量

Pyspark 从具有不同列的行/数据创建 DataFrame