How to find average of array of arrays in PySpark
Posted: 2019-12-10 01:20:06

[Question]: I have a PySpark dataframe in which one column (say B) is an array of arrays. Below is the PySpark dataframe:
+---+-----------------------------+---+
|A  |B                            |C  |
+---+-----------------------------+---+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |
|a  |[[5.0], [20.0, 80.0]]        |d  |
|a  |[[5.0], [25.0, 75.0]]        |e  |
|b  |[[5.0], [25.0, 75.0]]        |f  |
|b  |[[5.0], [12.0, 88.0]]        |g  |
+---+-----------------------------+---+
For each row, I want to find the number of elements and the average of all the elements (as separate columns).
Below is the expected output:
+---+-----------------------------+---+---+------+
|A  |B                            |C  |Num|   Avg|
+---+-----------------------------+---+---+------+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |4  | 23.75|
|a  |[[5.0], [20.0, 80.0]]        |d  |3  | 35.00|
|a  |[[5.0], [25.0, 75.0]]        |e  |3  | 35.00|
|b  |[[5.0], [25.0, 75.0]]        |f  |3  | 35.00|
|b  |[[5.0], [12.0, 88.0]]        |g  |3  | 35.00|
+---+-----------------------------+---+---+------+
What is an efficient way in PySpark to find the average of all the elements in an array of arrays (per row)?
Currently, I am using udfs to perform these operations. Below is the code I have so far:
from pyspark.sql import functions as F
import pyspark.sql.types as T
from pyspark.sql import *
from pyspark.sql.types import DecimalType
from pyspark.sql.functions import udf
import numpy as np

# UDF to find number of elements
def len_array_of_arrays(anomaly_in_issue_group_col):
    return sum([len(array_element) for array_element in anomaly_in_issue_group_col])

udf_len_array_of_arrays = F.udf(len_array_of_arrays, T.IntegerType())

# UDF to find average of all elements
def avg_array_of_arrays(anomaly_in_issue_group_col):
    return np.mean([element for array_element in anomaly_in_issue_group_col
                    for element in array_element])

udf_avg_array_of_arrays = F.udf(avg_array_of_arrays, T.DecimalType())

df.withColumn("Num", udf_len_array_of_arrays(F.col("B"))).withColumn(
    "Avg", udf_avg_array_of_arrays(F.col("B"))
).show(20, False)
The udf that finds the number of elements per row works. However, the udf that finds the average raises the following error:
---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-176-3253feca2963> in <module>()
      1 #df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).show(20, False)
----> 2 df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).withColumn("Avg" , udf_avg_array_of_arrays(F.col("B")) ).show(20, False)

/usr/lib/spark/python/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    378             print(self._jdf.showString(n, 20, vertical))
    379         else:
--> 380             print(self._jdf.showString(n, int(truncate), vertical))
    381
    382     def __repr__(self):

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258
   1259         for temp_arg in temp_args:

/usr/lib/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(
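The traceback is cut off before the Java-side error message, but a likely culprit is the udf's return type: np.mean returns a numpy.float64, which Spark cannot convert to the declared DecimalType(). A minimal sketch of the averaging udf returning a plain Python float with a DoubleType result (an assumed fix, not verified against the original error):

from pyspark.sql import functions as F
import pyspark.sql.types as T

# Return a plain Python float (not a numpy scalar) and declare the result as DoubleType.
def avg_array_of_arrays(anomaly_in_issue_group_col):
    flat = [element for array_element in anomaly_in_issue_group_col
            for element in array_element]
    return float(sum(flat) / len(flat)) if flat else None

udf_avg_array_of_arrays = F.udf(avg_array_of_arrays, T.DoubleType())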
[Question comments]:
Have you tried a combination of the explode (spark.apache.org/docs/latest/api/python/…) and Window (spark.apache.org/docs/latest/api/python/…) functions?

[Answer 1]: For Spark 2.4+, use flatten + aggregate:
from pyspark.sql.functions import expr

df.withColumn("Avg", expr("""
    aggregate(
        flatten(B)
      , (double(0) as total, int(0) as cnt)
      , (x, y) -> (x.total + y, x.cnt + 1)
      , z -> round(z.total / z.cnt, 2)
    )
""")).show()
+-----------------------------+---+-----+
|B                            |C  |Avg  |
+-----------------------------+---+-----+
|[[5.0], [25.0, 25.0], [40.0]]|c  |23.75|
|[[5.0], [25.0, 80.0]]        |d  |36.67|
|[[5.0], [25.0, 75.0]]        |e  |35.0 |
+-----------------------------+---+-----+
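The question also asks for the element count per row. Assuming Spark 2.4+ (as this answer does), size(flatten(B)) gives it without a udf; a sketch producing both requested columns alongside the same aggregate expression:

from pyspark.sql import functions as F

result = (df
    .withColumn("Num", F.size(F.flatten("B")))   # total number of elements per row
    .withColumn("Avg", F.expr("""
        aggregate(
            flatten(B)
          , (double(0) as total, int(0) as cnt)
          , (x, y) -> (x.total + y, x.cnt + 1)
          , z -> round(z.total / z.cnt, 2)
        )
    """)))
result.show(truncate=False)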
[Comments]:
[Answer 2]: Since Spark 1.4: explode() the column containing the arrays, as many times as there are levels of nesting. Use monotonically_increasing_id() to create an extra grouping key so that duplicate rows are not combined:
from pyspark.sql.functions import explode, sum, lit, avg, monotonically_increasing_id

df = spark.createDataFrame(
    (("a", [[1], [2, 3], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),  # DUPE!
     ("b", [[2, 3], [4]], "foo")),
    schema=("category", "arrays", "foo"))

df2 = (df.withColumn("id", monotonically_increasing_id())
       .withColumn("subarray", explode("arrays"))
       .withColumn("subarray", explode("subarray"))  # unnest another level
       .groupBy("category", "arrays", "foo", "id")
       .agg(sum(lit(1)).alias("number_of_elements"),
            avg("subarray").alias("avg"))
       .drop("id"))
df2.show()
# +--------+------------------+---+------------------+----+
# |category|            arrays|foo|number_of_elements| avg|
# +--------+------------------+---+------------------+----+
# |       a|[[5], [6, 0], [4]]|foo|                 4|3.75|
# |       b|     [[2, 3], [4]]|foo|                 3| 3.0|
# |       a|[[5], [6, 0], [4]]|foo|                 4|3.75|
# |       a|[[1], [2, 3], [4]]|foo|                 4| 2.5|
# +--------+------------------+---+------------------+----+
Spark 2.4 introduced 24 functions for working with complex types, along with higher-order functions (functions that take other functions as arguments, like Python 3's functools.reduce). They remove the boilerplate you see above. If you are on Spark 2.4+, see the answer from jxc.
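For reference, newer PySpark releases (3.1+ is assumed here, where aggregate is exposed in the Python API) let the same computation be written without an SQL expression string; a sketch against the question's column B:

from pyspark.sql import functions as F

# Assumes Spark 3.1+, where F.aggregate is available in the Python API.
df_avg = (df
    .withColumn("Num", F.size(F.flatten("B")))
    .withColumn("Sum", F.aggregate(F.flatten("B"), F.lit(0.0), lambda acc, x: acc + x))
    .withColumn("Avg", F.round(F.col("Sum") / F.col("Num"), 2))
    .drop("Sum"))
df_avg.show(truncate=False)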
[Comments]: