避免在 pyspark 代码中使用 collect() 函数的最佳方法是啥?编写优化pyspark代码的最佳方法?

Posted

技术标签:

【中文标题】避免在 pyspark 代码中使用 collect() 函数的最佳方法是啥?编写优化pyspark代码的最佳方法?【英文标题】:What is the best way to avoid using collect() function in the pyspark code? Best ways to write optimize pyspark code?避免在 pyspark 代码中使用 collect() 函数的最佳方法是什么?编写优化pyspark代码的最佳方法? 【发布时间】:2017-10-12 07:26:10 【问题描述】:

您好,我正在尝试在 pyspark 中编写代码以从数据框创建列表。我在我的代码中使用 collect() 函数,但不确定它是否是从数据框列中获取值的过滤器列表的正确方法。由于 collect() 将数据带入数据节点,因此在大数据帧(例如 10 GB)的情况下,这将是一个糟糕的选择。

以下是我的输入数据框 -

[Row(parent=u'p1', child=u'c1'), Row(parent=u'p11', child=u'p1'),
Row(parent=u'p111', child=u'p11'), Row(parent=u'p2', child=u'c2'), 
Row(parent=u'p22', child=u'p2'), Row(parent=u'p222', child=u'p22'),
Row(parent=u'p2222', child=u'p222')]

我想实现如下输出数据帧 -

[Row(parent=u'p2222', child1=u'p222', child2=u'p22', child3=u'p2',
child4=u'c2'), Row(parent=u'p111', child1=u'p11', child2=u'p1',
child3=u'c1', child4=None)]

以下是我编写的工作代码,但不确定它是否优化为 spark 以优化处理而闻名

from pyspark.sql import * 
from pyspark.sql.functions import * 
from pyspark.sql.types import *

customSchema = StructType([StructField('parent',StringType(),True),\
                           StructField('child',StringType(),True)]) 

#loading data from a CSV file and creating a dataframe
mydata = sqlContext.load(source='com.databricks.spark.csv',path='/FileStore/tables/34v0qouq1507635707462/parent_child_input.csv',header=True,schema=customSchema)
mydata.registerTempTable('mydata')

#creating a list of values of column "Child" from the dataframe "mydata"
childlist = [x[1] for x in mydata.collect()] 

#creating another dataframe with filter values of "Parent" column which are not present in childlist
level1 = mydata.selectExpr('parent','child as child1').where(~mydata.parent.isin(childlist)) 
i=1 

#Function to create dataframe containing desired output as mentioned above
def getChild(level1,i):
    cname = 'child'+str(i)
    tmp = [x[i] for x in level1.collect() if x[i]]
    tmp = list(set(tmp))
    if tmp.count(None)==1:
          tmp.remove(None)
    level1.registerTempTable('level1')
    if len(tmp)>0:
          i+=1
          ccname = 'child'+str(i)
          querystr='select level1.*,mydata.child as ' +ccname+\
             ' from level1 left outer join mydata on level1.'+cname+'=mydata.parent'
          level1 = sqlContext.sql(querystr)
          level1 = getChild(level1,i)
return level1

level1 = getChild(level1,i)
level1.drop('child5').show()

【问题讨论】:

【参考方案1】:

如果您的父母和孩子在您输入数据时会有整数,我们可以使用整数进行拆分和分组。尝试了我的方法,希望对您有所帮助。

>>> df.show()
+-----+------+
|child|parent|
+-----+------+
|   c1|    p1|
|   p1|   p11|
|  p11|  p111|
|   c2|    p2|
|   p2|   p22|
|  p22|  p222|
| p222| p2222|
+-----+------+

>>> udf = F.udf(lambda x,y : (x,y),ArrayType(StringType()))
>>> udf1 = F.udf(lambda x : tuple(filter(str.isdigit,x))[0],StringType())

>>> df1 = df.select("*",udf1('parent').alias('group'),udf('parent','child').alias('set'))
>>> df1.show()
+-----+------+-----+-------------+
|child|parent|group|          set|
+-----+------+-----+-------------+
|   c1|    p1|    1|     [p1, c1]|
|   p1|   p11|    1|    [p11, p1]|
|  p11|  p111|    1|  [p111, p11]|
|   c2|    p2|    2|     [p2, c2]|
|   p2|   p22|    2|    [p22, p2]|
|  p22|  p222|    2|  [p222, p22]|
| p222| p2222|    2|[p2222, p222]|
+-----+------+-----+-------------+

>>> udf2 = F.udf(lambda x :sorted(set(sum(x,[])),reverse=True),ArrayType(StringType()))
>>> df2 = df1.groupby('group').agg(udf2(F.collect_set('set')).alias('column'))
>>> df2.show(truncate=False)
+-----+--------------------------+
|group|column                    |
+-----+--------------------------+
|1    |[p111, p11, p1, c1]       |
|2    |[p2222, p222, p22, p2, c2]|
+-----+--------------------------+

maxval = df2[[F.max(F.size('column'))]].first()[0]
schema = StructType([StructField("parent",StringType(),True),StructField("child1",StringType(),True),StructField("child2",StringType(),True),StructField("child3",StringType(),True),StructField("child4",StringType(),True)])

udf3 = F.udf(lambda x : x if len(x) == maxval else x+[None]*(maxval -len(x)),schema)

>>> df2.select("*",udf3('column').alias('merged')).select("merged.*").show()
+------+------+------+------+------+
|parent|child1|child2|child3|child4|
+------+------+------+------+------+
|  p111|   p11|    p1|    c1|  null|
| p2222|  p222|   p22|    p2|    c2|
+------+------+------+------+------+

【讨论】:

以上是关于避免在 pyspark 代码中使用 collect() 函数的最佳方法是啥?编写优化pyspark代码的最佳方法?的主要内容,如果未能解决你的问题,请参考以下文章

在 PySpark 中使用 collect_list 时 Java 内存不足

Dataframe.rdd.map().collect 在 PySpark 中不起作用 [重复]

pyspark .show() 有效,但 .collect() 无效

Pyspark:使用 map 函数而不是 collect 来迭代 RDD

在 PySpark 中连接两个数据框时避免列重复列名

PySpark 和 Jupyter-notebook 中的 Collect() 错误