在 PySpark 数据框中的组中的列上应用函数

Posted

技术标签:

【中文标题】在 PySpark 数据框中的组中的列上应用函数【英文标题】:Apply a function over a column in a group in PySpark dataframe 【发布时间】:2019-08-22 06:20:28 【问题描述】:

我有一个像这样的 PySpark 数据框,

+----------+--------+---------+
|id_       | p      |   a     |
+----------+--------+---------+
|  1       | 4      |   12    |
|  1       | 3      |   14    |
|  1       | -7     |   16    |
|  1       | 5      |   11    |
|  1       | -20    |   90    |
|  1       | 5      |   120   |
|  2       |  11    |   267   |
|  2       | -98    |   124   |
|  2       | -87    |   120   |
|  2       | -1     |   44    |
|  2       |  5     |   1     |
|  2       |  7     |   23    |
-------------------------------

我也有这样的python函数,

def fun(x):
    total = 0
    result = np.empty_like(x)
    for i, y in enumerate(x):
        total += (y)
        if total < 0:
            total = 0
        result[i] = total

    return result

我想在 id_ 列上对 PySpark 数据框进行分组,并在 p 列上应用函数 fun

我想要类似的东西

spark_df.groupBy('id_')['p'].apply(fun)

我目前正在pyarrow 的帮助下使用 pandas udf 执行此操作,这在我的应用程序的时间方面效率不高。

我正在寻找的结果是,

[4, 7, 0, 5, 0, 5, 11, -98, -87, -1, 5, 7]

这是我正在寻找的结果数据框,

+----------+--------+---------+
|id_       | p      |   a     |
+----------+--------+---------+
|  1       | 4      |   12    |
|  1       | 7      |   14    |
|  1       | 0      |   16    |
|  1       | 5      |   11    |
|  1       | 0      |   90    |
|  1       | 5      |   120   |
|  2       |  11    |   267   |
|  2       | 0      |   124   |
|  2       | 0      |   120   |
|  2       | 0      |   44    |
|  2       |  5     |   1     |
|  2       |  12    |   23    |
-------------------------------

pyspark API 本身有直接的方法吗?

我可以使用 collect_listid_ 上进行分组并将p 聚合并列到一个列表中,并在此基础上使用udf 并使用explode 来获取我在结果数据框中需要的列p .

但是如何保留我的数据框中的其他列?

【问题讨论】:

【参考方案1】:

是的,您可以将上述 python 函数转换为 Pyspark UDF。 由于您要返回一个整数数组,因此将返回类型指定为ArrayType(IntegerType()) 很重要。

下面是代码,

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, collect_list

@udf(returnType=ArrayType(IntegerType()))
def fun(x):
    total = 0
    result = np.empty_like(x)
    for i, y in enumerate(x):
        total += (y)
        if total < 0:
            total = 0
        result[i] = total
    return result.tolist()    # Convert NumPy Array to Python List

由于udf 的输入必须是一个列表,让我们根据“id”对数据进行分组并将行转换为数组。

df = df.groupBy('id_').agg(collect_list('p'))
df = df.toDF('id_', 'p_')    # Assign a new alias name 'p_'
df.show(truncate=False)

输入数据:

+---+------------------------+
|id_|collect_list(p)         |
+---+------------------------+
|1  |[4, 3, -7, 5, -20, 5]   |
|2  |[11, -98, -87, -1, 5, 7]|
+---+------------------------+

接下来,我们在此数据上应用udf

df.select('id_', fun(df.p_)).show(truncate=False)

输出:

+---+--------------------+
|id_|fun(p_)             |
+---+--------------------+
|1  |[4, 7, 0, 5, 0, 5]  |
|2  |[11, 0, 0, 0, 5, 12]|
+---+--------------------+

【讨论】:

但我想要的是输出在列中。我猜我可以使用explode。查看预期数据帧的更新 Q【参考方案2】:

通过以下步骤,我设法达到了我需要的结果,

我的 DataFrame 看起来像这样,

+---+---+---+
|id_|  p|  a|
+---+---+---+
|  1|  4| 12|
|  1|  3| 14|
|  1| -7| 16|
|  1|  5| 11|
|  1|-20| 90|
|  1|  5|120|
|  2| 11|267|
|  2|-98|124|
|  2|-87|120|
|  2| -1| 44|
|  2|  5|  1|
|  2|  7| 23|
+---+---+---+

我将按id_ 上的数据框分组,并收集我想使用collect_list 将函数应用于列表的列,然后像这样应用函数,

agg_df = df.groupBy('id_').agg(F.collect_list('p').alias('collected_p'))
agg_df = agg_df.withColumn('new', fun('collected_p'))

我现在想以某种方式将 agg_df 合并到我的原始数据框。为此,我将首先使用explode 获取行中new 列中的值。

agg_df = agg_df.withColumn('exploded', F.explode('new'))

为了合并,我将使用monotonically_increasing_id 为原始数据框生成idagg_df。我将为每个数据帧创建idx,因为monotonically_increasing_id 对于两个数据帧都不相同。

agg_df = agg_df.withColumn('id_mono', F.monotonically_increasing_id())
df = df.withColumn('id_mono', F.monotonically_increasing_id())

w = Window().partitionBy(F.lit(0)).orderBy('id_mono')

df = df.withColumn('idx', F.row_number().over(w))
agg_df = agg_df.withColumn('idx', F.row_number().over(w))

df = df.join(agg_df.select('idx', 'exploded'), ['idx']).drop('id_mono', 'idx')


+---+---+---+--------+
|id_|  p|  a|exploded|
+---+---+---+--------+
|  1|  4| 12|       4|
|  1|  3| 14|       7|
|  1| -7| 16|       0|
|  1|  5| 11|       5|
|  1|-20| 90|       0|
|  1|  5|120|       5|
|  2| 11|267|      11|
|  2|-98|124|       0|
|  2|-87|120|       0|
|  2| -1| 44|       0|
|  2|  5|  1|       5|
|  2|  7| 23|      12|
+---+---+---+--------+

我不确定这是一种直接的方法。如果有人可以为此提出任何优化建议,那就太好了。

【讨论】:

以上是关于在 PySpark 数据框中的组中的列上应用函数的主要内容,如果未能解决你的问题,请参考以下文章

Pyspark:迭代数据框中的组

在 pyspark 中的特定列上应用过滤器描述

如何拆分对象列表以分隔pyspark数据框中的列

PySpark:将 RDD 转换为数据框中的列

遍历 pyspark 数据框中的列,而不为单个列创建不同的数据框

使用 pyspark 将 Spark 数据框中的列转换为数组 [重复]