pyspark 从数据帧迭代 N 行到每次执行

Posted

技术标签:

【中文标题】pyspark 从数据帧迭代 N 行到每次执行【英文标题】:pyspark iterate N rows from Data Frame to each execution 【发布时间】:2020-09-16 20:51:58 【问题描述】:
def fun_1(csv):
     # returns int[] of length = Number of New Lines in String csv

def fun_2(csv): # My WorkArround to Pass one CSV Line at One Time
     return fun_1(csv)[0]

输入数据框是df

+----+----+-----+
|col1|col2|CSVs |
+----+----+-----+
|   1|   a|2,0,1|
|   2|   b|2,0,2|
|   3|   c|2,0,3|
|   4|   a|2,0,1|
|   5|   b|2,0,2|
|   6|   c|2,0,3|
|   7|   a|2,0,1|
+----+----+-----+ 

下面是一个有效但需要很长时间的代码片段

from pyspark.sql.functions import udf
from pyspark.sql import functions as sf
funudf = udf(fun_2) # wish it could be fun_1
df=df.withColumn( 'pred' , funudf(sf.col('csv')))

fun_1 ,存在内存问题,一次最多只能处理 50000 行。我希望使用 funudf = udf(fun_1) 。 因此,如何将 PySpark DF 拆分为 50000 行的段,调用 funudf ->fun_1。 输出有两个列,来自输入的“col1”和“funudf 返回值”。

【问题讨论】:

您在运行udf(fun_1) 之前是否尝试过重新分区数据?您看到的内存问题到底是什么? 呼叫转到另一个计算具有挑战性的内存饥饿的服务。 如果我理解正确,fun_1 会调用此服务,该服务对 CSV 执行一些复杂的操作,并且会占用内存。通常,减少内存需求的最简单解决方案是使用随机密钥进行重新分区。它将默认为 200 个分区。因此,您可以在运行 UDF 之前尝试 df=df.repartition(800, 'some_key')。确保在fun_1 之前运行count 之类的操作,因为repartition 是惰性的。 【参考方案1】:

您可以通过使用 RDD API 中公开的groupByKey 方法来实现强制 PySpark 对固定批次的行进行操作的预期结果。使用 groupByKey 将强制 PySpark 将单个密钥的所有数据随机分配给单个执行程序。

注意:出于同样的原因,由于网络成本,通常不鼓励使用 groupByKey

策略:

    添加一列,将您的数据分组到所需的批次大小和groupByKey 定义一个函数来重现您的 UDF 的逻辑(并返回一个 id 以便稍后加入)。这在pyspark.resultiterable.ResultIterable 上运行,groupByKey 的结果。使用mapValues 将功能应用于您的组 将生成的 RDD 转换为 DataFrame 并重新加入。

例子:

# Synthesize DF
data = '_id': range(9), 'group': ['a', 'b', 'c', 'a', 'b', 'c', 'a', 'b', 'c'], 'vals': [2.0*i for i in range(9)]
df = spark.createDataFrame(pd.DataFrame(data))

df.show()

##
# Step - 1 Convert to rdd and groupByKey to force each group to separate executor
##
kv = df.rdd.map(lambda r: (r.group, [r._id, r.group, r.vals]))
groups = kv.groupByKey()

##
# Step 2 - Calulate function
##

# Dummy function taking 
def mult3(ditr):
    data = ditr.data
    ids = [v[0] for v in data]
    vals = [3*v[2] for v in data]
    return zip(ids, vals)

# run mult3 and flaten results
mv = groups.mapValues(mult3).map(lambda r: r[1]).flatMap(lambda r: r) # rdd[(id, val)]

## 
# Step 3 - Join results back into base DF
## 

# convert results into a DF and join back in
schema = t.StructType([t.StructField('_id', t.LongType()), t.StructField('vals_x_3', t.FloatType())])
df_vals = spark.createDataFrame(mv, schema)
joined = df.join(df_vals, '_id')

joined.show()

>>>

+---+-----+----+
|_id|group|vals|
+---+-----+----+
|  0|    a| 0.0|
|  1|    b| 2.0|
|  2|    c| 4.0|
|  3|    a| 6.0|
|  4|    b| 8.0|
|  5|    c|10.0|
|  6|    a|12.0|
|  7|    b|14.0|
|  8|    c|16.0|
+---+-----+----+

+---+-----+----+--------+
|_id|group|vals|vals_x_3|
+---+-----+----+--------+
|  0|    a| 0.0|     0.0|
|  7|    b|14.0|    42.0|
|  6|    a|12.0|    36.0|
|  5|    c|10.0|    30.0|
|  1|    b| 2.0|     6.0|
|  3|    a| 6.0|    18.0|
|  8|    c|16.0|    48.0|
|  2|    c| 4.0|    12.0|
|  4|    b| 8.0|    24.0|
+---+-----+----+--------+

【讨论】:

让我试试,请问你有办法使用 DataFrame 代替 RDD 不,我只知道基于 RDD 的方法。

以上是关于pyspark 从数据帧迭代 N 行到每次执行的主要内容,如果未能解决你的问题,请参考以下文章

Pyspark 在查找前一行时按组迭代数据帧

PySpark:我们应该迭代更新数据帧吗?

为啥在使用 pyspark 加入 Spark 数据帧时出现这些 Py4JJavaError showString 错误?

将 PySpark 数据帧写入 Parquet 文件时出现 Py4JJavaError

Spark迭代算法UDF在每次迭代中被多次触发

Spark:如何在每个执行程序中创建本地数据帧