Pyspark 窗口函数,具有对旅行者数量进行取整的条件
Posted
技术标签:
【中文标题】Pyspark 窗口函数,具有对旅行者数量进行取整的条件【英文标题】:Pyspark window function with conditions to round number of travelers 【发布时间】:2021-10-06 09:59:15 【问题描述】:我正在使用 Pyspark,我想创建一个执行以下操作的函数:
给定描述火车用户交易的数据:
+----+----------+--------+-----+
|date|total_trav|num_trav|order|
+----+----------+--------+-----+
| 1| 9| 2.7| 1|
| 1| 9| 1.3| 2|
| 1| 9| 1.3| 3|
| 1| 9| 1.3| 4|
| 1| 9| 1.2| 5|
| 1| 9| 1.1| 6|
| 2| 9| 2.7| 1|
| 2| 9| 1.3| 2|
| 2| 9| 1.3| 3|
| 2| 9| 1.3| 4|
| 2| 9| 1.2| 5|
| 2| 9| 1.1| 6|
+----+----------+--------+-----+
我想根据order
列中给出的顺序对num_trav
列的数字进行四舍五入,同时按date
分组以获得trav_res
列。
它背后的逻辑是这样的:
date
分组
对于每个分组数据(其中date=1
和date=2
),我们必须始终将数字舍入到上限(ceil(num_trav)
)(它们的值无关紧要,始终舍入到上限)。但考虑到我们按组 (total_trav
) 的旅客人数上限,在这种情况下,两组的旅客人数均为 9。
这是order
列的位置。您需要按该列给出的顺序开始四舍五入,并检查您为该组留下的旅行者数量。
例如,让我们考虑这个结果数据框,看看trav_res
列是如何形成的:
+----+----------+--------+-----+--------+
|date|total_trav|num_trav|order|trav_res|
+----+----------+--------+-----+--------+
| 1| 9| 2.7| 1| 3|
| 1| 9| 1.3| 2| 2|
| 1| 9| 1.3| 3| 2|
| 1| 9| 1.3| 4| 2|
| 1| 9| 1.2| 5| 0|
| 1| 9| 1.1| 6| 0|
| 2| 9| 2.7| 1| 3|
| 2| 9| 1.3| 2| 2|
| 2| 9| 1.3| 3| 2|
| 2| 9| 1.3| 4| 2|
| 2| 9| 1.2| 5| 0|
| 2| 9| 1.1| 6| 0|
+----+----------+--------+-----+--------+
在上面的示例中,当您按日期分组时,您将有 2 个组,最大旅行者数量为 9(total_trav
列)。
例如,对于第 1 组,你将开始将 num_trav=2.7
舍入为 3(trav_res
列),然后将 num_trav=1.3
舍入为 2,然后将 num_trav=1.3
舍入为 2,将 num_trav=1.3
舍入为 2(这是按照给定的顺序) ),然后对于下一个您没有剩余的旅行者,所以他们拥有的人数并不重要,因为没有剩余的旅行者,所以在这两种情况下他们都会得到trav_res=0
。
我已经尝试了一些 udf 函数,但你似乎没有完成这项工作。
【问题讨论】:
目前尚不清楚您要实现什么目标以及在第二个数据框中获取 trav_res 列的逻辑是什么。 @AnnaK。我编辑了问题并解释了示例以使其更清楚。如果还不清楚,请告诉我。谢谢 【参考方案1】:您可以先将 F.ceil 应用到 num_trav 中的所有行,然后根据上限值创建 cumsum 列,然后在 cumsum 超过 total_trav 时将上限值设置为零,如下代码所示
# create dataframe
import pyspark.sql.functions as F
from pyspark.sql import Window
data = [(1, 9, 2.7, 1),
(1, 9, 1.3, 2),
(1, 9, 1.3, 3),
(1, 9, 1.3, 4),
(1, 9, 1.2, 5),
(1, 9, 1.1, 6),
(2, 9, 2.7, 1),
(2, 9, 1.3, 2),
(2, 9, 1.3, 3),
(2, 9, 1.3, 4),
(2, 9, 1.2, 5),
(2, 9, 1.1, 6)]
df = spark.createDataFrame(data, schema=["date", "total_trav", "num_trav", "order"])
# create ceiling column
df = df.withColumn("num_trav_ceil", F.ceil("num_trav"))
# create cumulative sum column
w = Window.partitionBy("date").orderBy("order")
df = df.withColumn("num_trav_ceil_cumsum", F.sum("num_trav_ceil").over(w))
# impose 0 in trav_res when cumsum exceeds total_trav
df = (df
.withColumn("trav_res",
F.when(F.col("num_trav_ceil_cumsum")<=F.col("total_trav"),
F.col("num_trav_ceil"))
.otherwise(0))
.select("date", "total_trav", "num_trav", "order", "trav_res"))
【讨论】:
几乎是正确的。唯一缺少的是trav_res
中的值的总和必须与total_trav
相同。因此,如果您总共有 6 位旅行者,则需要安排这 6 位,而不是更少。我在您的原始代码中添加了这个小修改,现在我的工作就像一个魅力。非常感谢!!! (下面发布的代码修改作为答案,再次感谢,这只是要添加的一点点,但您的答案几乎是完美的)。【参考方案2】:
该解决方案基于@AnnaK。回答,再加一点。 这样就考虑到了必须使用的旅客总数 (total_trav),而不是更多,也不是更少。
# create ceiling column
df = df_j_test_res.withColumn("num_trav_ceil", F.ceil("num_trav"))
# create cumulative sum column
w = Window.partitionBy("date").orderBy("order")
df = df.withColumn("num_trav_ceil_cumsum", F.sum("num_trav_ceil").over(w))
# impose 0 in trav_res when cumsum exceeds total_trav
df = (df
.withColumn("trav_res",
F.when(F.col("num_trav_ceil_cumsum")<=F.col("total_trav"),
F.col("num_trav_ceil")
).when((F.col('num_trav_ceil_cumsum')-F.col('total_trav')>0) & ((F.col('num_trav_ceil_cumsum')-F.col('total_trav')<=1)),
1)
.otherwise(0))
.select("date", "total_trav", "num_trav", "order", "trav_res"))
【讨论】:
以上是关于Pyspark 窗口函数,具有对旅行者数量进行取整的条件的主要内容,如果未能解决你的问题,请参考以下文章