在 PySpark 数据框中的组中的列上应用函数
Posted
技术标签:
【中文标题】在 PySpark 数据框中的组中的列上应用函数【英文标题】:Apply a function over a column in a group in PySpark dataframe 【发布时间】:2019-08-22 06:20:28 【问题描述】:我有一个像这样的 PySpark 数据框,
+----------+--------+---------+
|id_ | p | a |
+----------+--------+---------+
| 1 | 4 | 12 |
| 1 | 3 | 14 |
| 1 | -7 | 16 |
| 1 | 5 | 11 |
| 1 | -20 | 90 |
| 1 | 5 | 120 |
| 2 | 11 | 267 |
| 2 | -98 | 124 |
| 2 | -87 | 120 |
| 2 | -1 | 44 |
| 2 | 5 | 1 |
| 2 | 7 | 23 |
-------------------------------
我也有这样的python函数,
def fun(x):
total = 0
result = np.empty_like(x)
for i, y in enumerate(x):
total += (y)
if total < 0:
total = 0
result[i] = total
return result
我想在 id_
列上对 PySpark 数据框进行分组,并在 p
列上应用函数 fun
。
我想要类似的东西
spark_df.groupBy('id_')['p'].apply(fun)
我目前正在pyarrow
的帮助下使用 pandas udf 执行此操作,这在我的应用程序的时间方面效率不高。
我正在寻找的结果是,
[4, 7, 0, 5, 0, 5, 11, -98, -87, -1, 5, 7]
这是我正在寻找的结果数据框,
+----------+--------+---------+
|id_ | p | a |
+----------+--------+---------+
| 1 | 4 | 12 |
| 1 | 7 | 14 |
| 1 | 0 | 16 |
| 1 | 5 | 11 |
| 1 | 0 | 90 |
| 1 | 5 | 120 |
| 2 | 11 | 267 |
| 2 | 0 | 124 |
| 2 | 0 | 120 |
| 2 | 0 | 44 |
| 2 | 5 | 1 |
| 2 | 12 | 23 |
-------------------------------
pyspark API 本身有直接的方法吗?
我可以使用 collect_list
在id_
上进行分组并将p
聚合并列到一个列表中,并在此基础上使用udf
并使用explode
来获取我在结果数据框中需要的列p
.
但是如何保留我的数据框中的其他列?
【问题讨论】:
【参考方案1】:是的,您可以将上述 python 函数转换为 Pyspark UDF。
由于您要返回一个整数数组,因此将返回类型指定为ArrayType(IntegerType())
很重要。
下面是代码,
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, collect_list
@udf(returnType=ArrayType(IntegerType()))
def fun(x):
total = 0
result = np.empty_like(x)
for i, y in enumerate(x):
total += (y)
if total < 0:
total = 0
result[i] = total
return result.tolist() # Convert NumPy Array to Python List
由于udf
的输入必须是一个列表,让我们根据“id”对数据进行分组并将行转换为数组。
df = df.groupBy('id_').agg(collect_list('p'))
df = df.toDF('id_', 'p_') # Assign a new alias name 'p_'
df.show(truncate=False)
输入数据:
+---+------------------------+
|id_|collect_list(p) |
+---+------------------------+
|1 |[4, 3, -7, 5, -20, 5] |
|2 |[11, -98, -87, -1, 5, 7]|
+---+------------------------+
接下来,我们在此数据上应用udf
,
df.select('id_', fun(df.p_)).show(truncate=False)
输出:
+---+--------------------+
|id_|fun(p_) |
+---+--------------------+
|1 |[4, 7, 0, 5, 0, 5] |
|2 |[11, 0, 0, 0, 5, 12]|
+---+--------------------+
【讨论】:
但我想要的是输出在列中。我猜我可以使用explode。查看预期数据帧的更新 Q【参考方案2】:通过以下步骤,我设法达到了我需要的结果,
我的 DataFrame 看起来像这样,
+---+---+---+
|id_| p| a|
+---+---+---+
| 1| 4| 12|
| 1| 3| 14|
| 1| -7| 16|
| 1| 5| 11|
| 1|-20| 90|
| 1| 5|120|
| 2| 11|267|
| 2|-98|124|
| 2|-87|120|
| 2| -1| 44|
| 2| 5| 1|
| 2| 7| 23|
+---+---+---+
我将按id_
上的数据框分组,并收集我想使用collect_list
将函数应用于列表的列,然后像这样应用函数,
agg_df = df.groupBy('id_').agg(F.collect_list('p').alias('collected_p'))
agg_df = agg_df.withColumn('new', fun('collected_p'))
我现在想以某种方式将 agg_df
合并到我的原始数据框。为此,我将首先使用explode 获取行中new
列中的值。
agg_df = agg_df.withColumn('exploded', F.explode('new'))
为了合并,我将使用monotonically_increasing_id
为原始数据框生成id
和agg_df
。我将为每个数据帧创建idx
,因为monotonically_increasing_id
对于两个数据帧都不相同。
agg_df = agg_df.withColumn('id_mono', F.monotonically_increasing_id())
df = df.withColumn('id_mono', F.monotonically_increasing_id())
w = Window().partitionBy(F.lit(0)).orderBy('id_mono')
df = df.withColumn('idx', F.row_number().over(w))
agg_df = agg_df.withColumn('idx', F.row_number().over(w))
df = df.join(agg_df.select('idx', 'exploded'), ['idx']).drop('id_mono', 'idx')
+---+---+---+--------+
|id_| p| a|exploded|
+---+---+---+--------+
| 1| 4| 12| 4|
| 1| 3| 14| 7|
| 1| -7| 16| 0|
| 1| 5| 11| 5|
| 1|-20| 90| 0|
| 1| 5|120| 5|
| 2| 11|267| 11|
| 2|-98|124| 0|
| 2|-87|120| 0|
| 2| -1| 44| 0|
| 2| 5| 1| 5|
| 2| 7| 23| 12|
+---+---+---+--------+
我不确定这是一种直接的方法。如果有人可以为此提出任何优化建议,那就太好了。
【讨论】:
以上是关于在 PySpark 数据框中的组中的列上应用函数的主要内容,如果未能解决你的问题,请参考以下文章