UDF 的性能改进 - 在 pyspark 中获取每行最小值的列名

Posted

技术标签:

【中文标题】UDF 的性能改进 - 在 pyspark 中获取每行最小值的列名【英文标题】:Performance improvement for UDFs - get column name of least value per row in pyspark 【发布时间】:2020-09-10 13:26:22 【问题描述】:

我使用这个 udf:

mincol = F.udf(lambda row: cols[row.index(min(row))], StringType())
df = df.withColumn("mycol", mincol(F.struct([df[x] for x in cols])))

获取每行最小值的列名作为另一个名为“mycol”的列的值。

但是这段代码很慢。 有什么提高性能的建议吗? 我正在使用 Pyspark 2.3

【问题讨论】:

有多慢?多少时间 ?多少行? @Steven 它没有完成一次。接近 1.9 亿行。 能否请您发布示例数据和预期输出 你好@Mysterious,接下来的答案是否提高了您程序的性能? 【参考方案1】:

这是 Spark 2.3 的另一种解决方案,它仅使用内置函数:

from sys import float_info
from pyspark.sql.functions import array, least, col, lit, concat_ws, expr

cols = df.columns
col_names = array(list(map(lit, cols)))
set_cols = list(map(col, cols))

# replace null with largest python float
df.na.fill(float_info.max) \
  .withColumn("min", least(*cols)) \
  .withColumn("cnames", col_names) \
  .withColumn("set", concat_ws(",", *set_cols)) \
  .withColumn("min_col", expr("cnames[find_in_set(min, set) - 1]")) \
  .select(*[cols + ["min_col"]]) \
  .show()

步骤:

    用可能的较大浮点数填充所有nulls。这是一个很好的空值替换候选者,因为很难找到更大的值。 使用least 查找最小列。 创建列cnames 用于存储列名。 创建列set,其中包含以逗号分隔的字符串形式的所有值。 使用find_in_set 创建列min_col。该函数分别处理每个字符串项,并将返回找到的项的索引。最后,我们使用带有cnames[indx - 1] 的索引来检索列名。

【讨论】:

【参考方案2】:

这是一种没有 udf 的方法。这个想法是创建一个包含每列的值和名称的数组,然后对该数组进行排序。

df1 = spark.createDataFrame([
    (1., 2., 3.),(3.,2.,1.), (9.,8.,-1.), (1.2, 1.2, 9.1), (3., None, 1.0)], \
        ["col1", "col2", "col3"])

cols = df1.columns
col_string = ', '.join("'0'".format(c) for c in cols)

df1 = df1.withColumn("vals", F.array(cols)) \
    .withColumn("cols", F.expr("Array(" + col_string + ")")) \
    .withColumn("zipped", F.arrays_zip("vals", "cols")) \
    .withColumn("without_nulls", F.expr("filter(zipped, x -> not x.vals is null)")) \
    .withColumn("sorted", F.expr("array_sort(without_nulls)")) \
    .withColumn("min", F.col("sorted")[0].cols) \
    .drop("vals", "cols", "zipped", "without_nulls", "sorted")
df1.show(truncate=False)

打印

+----+----+----+----+                                                           
|col1|col2|col3|min |
+----+----+----+----+
|1.0 |2.0 |3.0 |col1|
|3.0 |2.0 |1.0 |col3|
|9.0 |8.0 |-1.0|col3|
|1.2 |1.2 |9.1 |col1|
|3.0 |null|1.0 |col3|
+----+----+----+----+

【讨论】:

以上是关于UDF 的性能改进 - 在 pyspark 中获取每行最小值的列名的主要内容,如果未能解决你的问题,请参考以下文章

哪个选项使用 pyspark 提供最佳性能?使用地图进行 UDF 或 RDD 处理?

PySpark - 从 UDF 获取行索引

如何在 Scala Spark 项目中使用 PySpark UDF?

如何使用 Hive 上下文中的 Pyspark 调用用 Java 编写的 Hive UDF

PySpark UDF 返回可变大小的元组

在 PySpark 中重新加载 UDF