在 pyspark 中应用 udf 过滤功能

Posted

技术标签:

【中文标题】在 pyspark 中应用 udf 过滤功能【英文标题】:Apply a udf filtering function in pyspark 【发布时间】:2021-04-16 22:55:21 【问题描述】:

我正在尝试对 pyspark 数据帧的特定范围内的值进行过滤和求和。当我使用此代码时它可以工作:

load_1=[]
for i in range(df.count()):
    start_t = df.select('start_time').where(df.id == i).collect()[0][0]
    try:
        load_1.append(df.where((df.start_time <= start_t) & (start_t <= df.end_time_1)).agg("pkt_size":"sum").collect()[0][0])
    except:
        load_1.append(0)

但是速度很慢。我试图用 udf 加速它并这样做:

def get_load(a, df=df):
    try:
        return df.where((df.start_time <= a) & (a <= df.end_time_1)).agg("pkt_size":"sum").collect()[0][0]
    except:
        return 0

loader = f.udf(get_load)
df.withColumn('load_1', loader(df.start_time).show())

使用此方法时出现此错误:

Could not serialize object: TypeError: can't pickle _thread.RLock objects

关于如何解决这个问题或如何加快它的任何想法?我正在尝试做一些类似于我们在 pandas 中执行的函数 apply 的事情。数据非常大(接近 40G),我可以使用的资源越多越好。 提前致谢!以下是数据示例:

+---+-------------+-----------------+--------+
| id|   start_time|       end_time_1|pkt_size|
+---+-------------+-----------------+--------+
|  1|1000000000000| 1.00000000192E12|    66.0|
|  2|1000000000000| 1.00000000192E12|    66.0|
|  3|1000000006478|1.000000008398E12|    66.0|
|  4|1000000006478|1.000000008398E12|    66.0|
|  5|1000000012956|1.000000014556E12|    58.0|
|  6|1000000012956|1.000000014556E12|    58.0|
|  7|1000000012957|1.000000016156E12|  1518.0|
|  8|1000000012957|1.000000016156E12|  1518.0|
|  9|1000000012957|1.000000017756E12|  1518.0|
| 10|1000000012957|1.000000017756E12|  1518.0|
| 11|1000000012957|1.000000019356E12|  1518.0|
| 12|1000000012957|1.000000019356E12|  1518.0|
| 13|1000000012957|1.000000020956E12|  1518.0|
| 14|1000000012957|1.000000020956E12|  1518.0|
| 15|1000000012957|1.000000022556E12|  1518.0|
| 16|1000000012957|1.000000022556E12|  1518.0|
| 17|1000000012957|1.000000024156E12|  1518.0|
| 18|1000000012957|1.000000024156E12|  1518.0|
| 19|1000000012957|1.000000025756E12|  1518.0|
| 20|1000000012957|1.000000025756E12|  1518.0|
+---+-------------+-----------------+--------+
only showing top 20 rows

我们的目标是将所有 start_time 小于每个 id 的 start_time 并且 end_time 大于 id 的 start_time 的行的 pkt_size 相加。所以过滤器是基于每一行的start_time。

【问题讨论】:

您能否提供一些示例数据来说明您要达到的目标? 我刚刚添加了几行。问题是过滤器值会根据每行数据的 start_time 变化,并且应该根据它们的 start_time 计算所有 id。 【参考方案1】:

有一种方法可以在没有循环或udf的情况下实现结果:

使用测试数据

+---+----------+----------+--------+                                            
| id|start_time|end_time_1|pkt_size|
+---+----------+----------+--------+
|  1|         2|         5|       4|
|  2|         1|         6|       5|
|  3|         1|         7|       6|
|  4|         5|         6|       7|
|  5|         4|         7|       8|
|  6|         3|         8|       9|
|  7|         6|         7|      10|
+---+----------+----------+--------+

代码

from pyspark.sql import functions as F

data = [(1, 2, 5, 4), 
        (2, 1, 6, 5),
        (3, 1, 7, 6),
        (4, 5, 6, 7),
        (5, 4, 7, 8),
        (6, 3, 8, 9),
        (7, 6, 7, 10)]

df = spark.createDataFrame(data, schema=['id', 'start_time', 'end_time_1', 'pkt_size'])

df.select('id', 'start_time', 'end_time_1') \
    .join(df.selectExpr('start_time as st_1', 'end_time_1 as et_1', 'pkt_size'), \
        F.expr('st_1 < start_time and et_1 > end_time_1')) \
    .groupBy('id') \
    .agg(F.sum('pkt_size')) \
    .show()

打印

+---+-------------+
| id|sum(pkt_size)|
+---+-------------+
|  7|            9|
|  5|            9|
|  1|           11|
|  4|           23|
+---+-------------+

在此示例中,对于 id 1,添加了第 2 行和第 3 行,对于 id 4,添加了第 3、5 和 6 行。

逻辑与问题中的相同,但计算是对所有 id 并行执行的,而不是一个接一个。这种方法需要自连接,因此 Spark 集群应该足够大。

【讨论】:

以上是关于在 pyspark 中应用 udf 过滤功能的主要内容,如果未能解决你的问题,请参考以下文章

在 PySpark 中重新加载 UDF

PySpark UDF,输入端只有 None 值

无法在 pyspark 中应用标量 pandas udf

PySpark 分组并逐行应用 UDF 操作

Pyspark:在UDF中传递多列以及参数

PySpark 结构化流将 udf 应用于窗口