在 pyspark 中应用 udf 过滤功能
Posted
技术标签:
【中文标题】在 pyspark 中应用 udf 过滤功能【英文标题】:Apply a udf filtering function in pyspark 【发布时间】:2021-04-16 22:55:21 【问题描述】:我正在尝试对 pyspark 数据帧的特定范围内的值进行过滤和求和。当我使用此代码时它可以工作:
load_1=[]
for i in range(df.count()):
start_t = df.select('start_time').where(df.id == i).collect()[0][0]
try:
load_1.append(df.where((df.start_time <= start_t) & (start_t <= df.end_time_1)).agg("pkt_size":"sum").collect()[0][0])
except:
load_1.append(0)
但是速度很慢。我试图用 udf 加速它并这样做:
def get_load(a, df=df):
try:
return df.where((df.start_time <= a) & (a <= df.end_time_1)).agg("pkt_size":"sum").collect()[0][0]
except:
return 0
loader = f.udf(get_load)
df.withColumn('load_1', loader(df.start_time).show())
使用此方法时出现此错误:
Could not serialize object: TypeError: can't pickle _thread.RLock objects
关于如何解决这个问题或如何加快它的任何想法?我正在尝试做一些类似于我们在 pandas 中执行的函数 apply 的事情。数据非常大(接近 40G),我可以使用的资源越多越好。 提前致谢!以下是数据示例:
+---+-------------+-----------------+--------+
| id| start_time| end_time_1|pkt_size|
+---+-------------+-----------------+--------+
| 1|1000000000000| 1.00000000192E12| 66.0|
| 2|1000000000000| 1.00000000192E12| 66.0|
| 3|1000000006478|1.000000008398E12| 66.0|
| 4|1000000006478|1.000000008398E12| 66.0|
| 5|1000000012956|1.000000014556E12| 58.0|
| 6|1000000012956|1.000000014556E12| 58.0|
| 7|1000000012957|1.000000016156E12| 1518.0|
| 8|1000000012957|1.000000016156E12| 1518.0|
| 9|1000000012957|1.000000017756E12| 1518.0|
| 10|1000000012957|1.000000017756E12| 1518.0|
| 11|1000000012957|1.000000019356E12| 1518.0|
| 12|1000000012957|1.000000019356E12| 1518.0|
| 13|1000000012957|1.000000020956E12| 1518.0|
| 14|1000000012957|1.000000020956E12| 1518.0|
| 15|1000000012957|1.000000022556E12| 1518.0|
| 16|1000000012957|1.000000022556E12| 1518.0|
| 17|1000000012957|1.000000024156E12| 1518.0|
| 18|1000000012957|1.000000024156E12| 1518.0|
| 19|1000000012957|1.000000025756E12| 1518.0|
| 20|1000000012957|1.000000025756E12| 1518.0|
+---+-------------+-----------------+--------+
only showing top 20 rows
我们的目标是将所有 start_time 小于每个 id 的 start_time 并且 end_time 大于 id 的 start_time 的行的 pkt_size 相加。所以过滤器是基于每一行的start_time。
【问题讨论】:
您能否提供一些示例数据来说明您要达到的目标? 我刚刚添加了几行。问题是过滤器值会根据每行数据的 start_time 变化,并且应该根据它们的 start_time 计算所有 id。 【参考方案1】:有一种方法可以在没有循环或udf的情况下实现结果:
使用测试数据
+---+----------+----------+--------+
| id|start_time|end_time_1|pkt_size|
+---+----------+----------+--------+
| 1| 2| 5| 4|
| 2| 1| 6| 5|
| 3| 1| 7| 6|
| 4| 5| 6| 7|
| 5| 4| 7| 8|
| 6| 3| 8| 9|
| 7| 6| 7| 10|
+---+----------+----------+--------+
代码
from pyspark.sql import functions as F
data = [(1, 2, 5, 4),
(2, 1, 6, 5),
(3, 1, 7, 6),
(4, 5, 6, 7),
(5, 4, 7, 8),
(6, 3, 8, 9),
(7, 6, 7, 10)]
df = spark.createDataFrame(data, schema=['id', 'start_time', 'end_time_1', 'pkt_size'])
df.select('id', 'start_time', 'end_time_1') \
.join(df.selectExpr('start_time as st_1', 'end_time_1 as et_1', 'pkt_size'), \
F.expr('st_1 < start_time and et_1 > end_time_1')) \
.groupBy('id') \
.agg(F.sum('pkt_size')) \
.show()
打印
+---+-------------+
| id|sum(pkt_size)|
+---+-------------+
| 7| 9|
| 5| 9|
| 1| 11|
| 4| 23|
+---+-------------+
在此示例中,对于 id 1,添加了第 2 行和第 3 行,对于 id 4,添加了第 3、5 和 6 行。
逻辑与问题中的相同,但计算是对所有 id 并行执行的,而不是一个接一个。这种方法需要自连接,因此 Spark 集群应该足够大。
【讨论】:
以上是关于在 pyspark 中应用 udf 过滤功能的主要内容,如果未能解决你的问题,请参考以下文章