Pyspark - 通过用户定义的聚合函数和旋转创建数据框
Posted
技术标签:
【中文标题】Pyspark - 通过用户定义的聚合函数和旋转创建数据框【英文标题】:Pyspark - Creating a dataframe by user defined aggregate function and pivoting 【发布时间】:2018-07-23 16:27:38 【问题描述】:我需要编写一个用户定义的聚合函数,用于捕获每次连续访问的上一次出院日期和下一次准入日期之间的天数。
我还需要以“PERSON_ID”值为中心。
我有以下 input_df :
input_df :
+---------+----------+--------------+
|PERSON_ID|ADMIT_DATE|DISCHARGE_DATE|
+---------+----------+--------------+
| 111|2018-03-15| 2018-03-16|
| 333|2018-06-10| 2018-06-11|
| 111|2018-03-01| 2018-03-02|
| 222|2018-12-01| 2018-12-02|
| 222|2018-12-05| 2018-12-06|
| 111|2018-03-30| 2018-03-31|
| 333|2018-06-01| 2018-06-02|
| 333|2018-06-20| 2018-06-21|
| 111|2018-01-01| 2018-01-02|
+---------+----------+--------------+
首先,我需要按每个人分组,并按 ADMIT_DATE 对相应的行进行排序。这将产生“input_df2”。
input_df2:
+---------+----------+--------------+
|PERSON_ID|ADMIT_DATE|DISCHARGE_DATE|
+---------+----------+--------------+
| 111|2018-01-01| 2018-01-03|
| 111|2018-03-01| 2018-03-02|
| 111|2018-03-15| 2018-03-16|
| 111|2018-03-30| 2018-03-31|
| 222|2018-12-01| 2018-12-02|
| 222|2018-12-05| 2018-12-06|
| 333|2018-06-01| 2018-06-02|
| 333|2018-06-10| 2018-06-11|
| 333|2018-06-20| 2018-06-21|
+---------+----------+--------------+
The desired output_df :
+------------------+-----------------+-----------------+----------------+
|PERSON_ID_DISTINCT| FIRST_DIFFERENCE|SECOND_DIFFERENCE|THIRD_DIFFERENCE|
+------------------+-----------------+-----------------+----------------+
| 111| 1 month 26 days | 13 days| 14 days|
| 222| 3 days| NAN| NAN|
| 333| 8 days| 9 days| NAN|
+------------------+-----------------+-----------------+----------------+
我知道一个人在我的 input_df 中出现的最大数量,所以我知道应该创建多少列:
print input_df.groupBy('PERSON_ID').count().sort('count', ascending=False).show(5)
非常感谢,
【问题讨论】:
【参考方案1】:您可以使用pyspark.sql.functions.datediff()
计算两个日期之间的天数差。在这种情况下,您只需要计算当前行的ADMIT_DATE
和前一行的DISCHARGE_DATE
之间的差异。您可以通过使用 pyspark.sql.functions.lag()
而不是 Window
来做到这一点。
例如,我们可以将访问之间的持续时间计算为新列DURATION
。
import pyspark.sql.functions as f
from pyspark.sql import Window
w = Window.partitionBy('PERSON_ID').orderBy('ADMIT_DATE')
input_df.withColumn(
'DURATION',
f.datediff(f.col('ADMIT_DATE'), f.lag('DISCHARGE_DATE').over(w))
)\
.withColumn('INDEX', f.row_number().over(w)-1)\
.sort('PERSON_ID', 'INDEX')\
.show()
#+---------+----------+--------------+--------+-----+
#|PERSON_ID|ADMIT_DATE|DISCHARGE_DATE|DURATION|INDEX|
#+---------+----------+--------------+--------+-----+
#| 111|2018-01-01| 2018-01-02| null| 0|
#| 111|2018-03-01| 2018-03-02| 58| 1|
#| 111|2018-03-15| 2018-03-16| 13| 2|
#| 111|2018-03-30| 2018-03-31| 14| 3|
#| 222|2018-12-01| 2018-12-02| null| 0|
#| 222|2018-12-05| 2018-12-06| 3| 1|
#| 333|2018-06-01| 2018-06-02| null| 0|
#| 333|2018-06-10| 2018-06-11| 8| 1|
#| 333|2018-06-20| 2018-06-21| 9| 2|
#+---------+----------+--------------+--------+-----+
注意,我还使用pyspark.sql.functions.row_number()
添加了一个INDEX
列。我们可以只过滤INDEX > 0
(因为第一个值总是null
),然后只旋转DataFrame:
input_df.withColumn(
'DURATION',
f.datediff(f.col('ADMIT_DATE'), f.lag('DISCHARGE_DATE').over(w))
)\
.withColumn('INDEX', f.row_number().over(w) - 1)\
.where('INDEX > 0')\
.groupBy('PERSON_ID').pivot('INDEX').agg(f.first('DURATION'))\
.sort('PERSON_ID')\
.show()
#+---------+---+----+----+
#|PERSON_ID| 1| 2| 3|
#+---------+---+----+----+
#| 111| 58| 13| 14|
#| 222| 3|null|null|
#| 333| 8| 9|null|
#+---------+---+----+----+
现在您可以将列重命名为您想要的任何名称。
注意:这里假设 ADMIT_DATE
和 DISCHARGE_DATE
的类型为 date
。
input_df.printSchema()
#root
# |-- PERSON_ID: long (nullable = true)
# |-- ADMIT_DATE: date (nullable = true)
# |-- DISCHARGE_DATE: date (nullable = true)
【讨论】:
保罗,非常感谢您的回答,效果很好!我还有一个关于在 for 循环中动态重命名多个列的问题。我检查了@zero323 的答案。假设目前我们有output_df2
。这种方式有效:mapping =dict(zip(['1', '2','3'], ['1_dif_time', '2_dif_time', '3_dif_time' ]))
然后output_df2 = output_df2.select([col(c).alias(mapping.get(c, c)) for c in output_df2.columns])
但是,我想以更有效的方式在 for 循环中执行此操作。你能帮帮我吗?
你现在的做法有什么问题?这几乎是最有效的。一种替代方法是for c in '1', '2', '3': output_df2 = output_df2.withColumnRenamed(c, c + "_dif_time")
。在 cmets 中很难回答后续问题。如果您有进一步的疑问,请考虑提出一个新问题,但请务必在问题中澄清为什么它不是现有帖子的重复。以上是关于Pyspark - 通过用户定义的聚合函数和旋转创建数据框的主要内容,如果未能解决你的问题,请参考以下文章
尝试通过数据框在 Pyspark 中执行用户定义的函数时出错