How to find average of array of arrays in PySpark


Posted: 2019-12-10 01:20:06

Question:

I have a PySpark dataframe in which one column (say B) is an array of arrays. Here is the PySpark dataframe:

+---+-----------------------------+---+
|A  |B                            |C  |
+---+-----------------------------+---+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |
|a  |[[5.0], [20.0, 80.0]]        |d  |
|a  |[[5.0], [25.0, 75.0]]        |e  |
|b  |[[5.0], [25.0, 75.0]]        |f  |
|b  |[[5.0], [12.0, 88.0]]        |g  |
+---+-----------------------------+---+

I want to find the number of elements and the average of all the elements for each row (as separate columns).

Below is the expected output:

+---+-----------------------------+---+---+------+
|A  |B                            |C  |Num|   Avg|
+---+-----------------------------+---+---+------+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |4  | 23.75|
|a  |[[5.0], [20.0, 80.0]]        |d  |3  | 35.00|
|a  |[[5.0], [25.0, 75.0]]        |e  |3  | 35.00|
|b  |[[5.0], [25.0, 75.0]]        |f  |3  | 35.00|
|b  |[[5.0], [12.0, 88.0]]        |g  |3  | 35.00|
+---+-----------------------------+---+---+------+

What is an efficient way to find the average of all the elements in an array of arrays (per row) in PySpark?

Currently, I am using udfs to perform these operations. Below is the code I have so far:

from pyspark.sql import functions as F
import pyspark.sql.types as T
from pyspark.sql import *
from pyspark.sql.types import DecimalType
from pyspark.sql.functions import udf
import numpy as np

#UDF to find number of elements
def len_array_of_arrays(anomaly_in_issue_group_col):
    return sum([len(array_element) for array_element in anomaly_in_issue_group_col])

udf_len_array_of_arrays = F.udf( len_array_of_arrays , T.IntegerType() )

#UDF to find average of all elements
def avg_array_of_arrays(anomaly_in_issue_group_col):
    return np.mean( [ element for array_element in anomaly_in_issue_group_col for element in array_element] )

udf_avg_array_of_arrays = F.udf( avg_array_of_arrays , T.DecimalType() )

df.withColumn("Num", udf_len_array_of_arrays(F.col("B"))).withColumn(
    "Avg", udf_avg_array_of_arrays(F.col("B"))
).show(20, False)

The udf for finding the number of elements in each row works. However, the udf for finding the average throws the following error:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-176-3253feca2963> in <module>()
      1 #df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).show(20, False)
----> 2 df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).withColumn("Avg" ,  udf_avg_array_of_arrays(F.col("B")) ).show(20, False)

/usr/lib/spark/python/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    378             print(self._jdf.showString(n, 20, vertical))
    379         else:
--> 380             print(self._jdf.showString(n, int(truncate), vertical))
    381 
    382     def __repr__(self):

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/usr/lib/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Comments:

There is. Have you tried a combination of the explode (spark.apache.org/docs/latest/api/python/…) and window (spark.apache.org/docs/latest/api/python/…) functions?

Answer 1:

For Spark 2.4+, use flatten + aggregate:

from pyspark.sql.functions import expr

df.withColumn("Avg", expr("""
    aggregate(
        flatten(B)
      , (double(0) as total, int(0) as cnt)
      , (x,y) -> (x.total+y, x.cnt+1)
      , z -> round(z.total/z.cnt,2)
    ) 
 """)).show()
+-----------------------------+---+-----+
|B                            |C  |Avg  |
+-----------------------------+---+-----+
|[[5.0], [25.0, 25.0], [40.0]]|c  |23.75|
|[[5.0], [25.0, 80.0]]        |d  |36.67|
|[[5.0], [25.0, 75.0]]        |e  |35.0 |
+-----------------------------+---+-----+
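The question also asks for the element count per row. Assuming the same Spark 2.4+ setup, a sketch that adds it without a UDF is size(flatten(B)) alongside the aggregate above:

from pyspark.sql.functions import expr

# Assumption: df is the questioner's dataframe and Spark is 2.4+.
# size(flatten(B)) counts every inner element; aggregate computes the mean.
df.withColumn("Num", expr("size(flatten(B))")) \
  .withColumn("Avg", expr("""
      aggregate(
          flatten(B),
          (double(0) as total, int(0) as cnt),
          (x, y) -> (x.total + y, x.cnt + 1),
          z -> round(z.total / z.cnt, 2)
      )""")) \
  .show(truncate=False)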

Comments:

Answer 2:

Since Spark 1.4:

explode() the column containing the arrays, once per level of nesting. Use monotonically_increasing_id() to create an extra grouping key so that duplicate rows are not combined:

from pyspark.sql.functions import explode, sum, lit, avg, monotonically_increasing_id

df = spark.createDataFrame(
    (("a", [[1], [2, 3], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),  # DUPE!
     ("b", [[2, 3], [4]], "foo")),
    schema=("category", "arrays", "foo"))

df2 = (df.withColumn("id", monotonically_increasing_id())
       .withColumn("subarray", explode("arrays"))
       .withColumn("subarray", explode("subarray"))  # unnest another level
       .groupBy("category", "arrays", "foo", "id")
       .agg(sum(lit(1)).alias("number_of_elements"),
            avg("subarray").alias("avg")).drop("id"))
df2.show()
# +--------+------------------+---+------------------+----+  
# |category|            arrays|foo|number_of_elements| avg|
# +--------+------------------+---+------------------+----+
# |       a|[[5], [6, 0], [4]]|foo|                 4|3.75|
# |       b|     [[2, 3], [4]]|foo|                 3| 3.0|
# |       a|[[5], [6, 0], [4]]|foo|                 4|3.75|
# |       a|[[1], [2, 3], [4]]|foo|                 4| 2.5|
# +--------+------------------+---+------------------+----+

Spark 2.4 introduced 24 functions for handling complex types, along with higher-order functions (functions that take functions as arguments, like Python 3's functools.reduce). They take away the boilerplate you see above. If you are on Spark 2.4+, see the answer from jxc.
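As a pure-Python analogy only (not Spark code), the aggregate higher-order function behaves much like functools.reduce with an initial accumulator and a final finishing step:

from functools import reduce

# Analogy for aggregate(flatten(B), (0.0, 0), merge, finish) on one row's values.
flat = [5.0, 25.0, 25.0, 40.0]
total, cnt = reduce(lambda acc, y: (acc[0] + y, acc[1] + 1), flat, (0.0, 0))
avg = round(total / cnt, 2)  # 23.75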

Comments:
