What is the best way in Spark to get each group as a new dataframe and pass it on to another function in a loop?

Posted: 2020-03-12 11:32:41

I am using spark-sql-2.4.1v and I am trying to find the quantiles (percentile 0, percentile 25, etc.) on each column of my data.

My data:

+----+---------+-------------+----------+-----------+--------+
|  id|     date|      revenue|con_dist_1| con_dist_2| state  |
+----+---------+-------------+----------+-----------+--------+
|  10|1/15/2018|  0.010680705|         6|0.019875458|   TX   |
|  10|1/15/2018|  0.006628853|         4|0.816039063|   AZ   |
|  10|1/15/2018|   0.01378215|         4|0.082049528|   TX   |
|  10|1/15/2018|  0.010680705|         6|0.019875458|   TX   |
|  10|1/15/2018|  0.006628853|         4|0.816039063|   AZ   |
|  10|1/15/2018|   0.01378215|         4|0.082049528|   CA   |
|  10|1/15/2018|  0.010680705|         6|0.019875458|   CA   |
|  10|1/15/2018|  0.006628853|         4|0.816039063|   CA   |
+----+---------+-------------+----------+-----------+--------+

I want to run the calculation for a given set of states and columns, i.e.

val states = Seq("CA","AZ");
val cols = Seq("con_dist_1" ,"con_dist_2")

For each given state, I need to take the data from the source table and calculate the percentiles only for the given columns.

I am trying something like the following:

for (state <- states)
  for (col <- cols)
    // percentile calculation

This is too slow. Also, when I group by "state" I do not get the other columns such as revenue, date, and id. How can I keep those?

How can I find the quantiles of the "con_dist_1" and "con_dist_2" columns for each state, and what is the best way to do it so that it scales well on the cluster?

What is the best way to handle this use case?

Expected result:

+-----+---------------+---------------+---------------+---------------+---------------+---------------+
|state|col1_quantile_1|col1_quantile_2|col1_quantile_3|col2_quantile_1|col2_quantile_2|col2_quantile_3|
+-----+---------------+---------------+---------------+---------------+---------------+---------------+
|   AZ|              4|              4|              4|    0.816039063|    0.816039063|    0.816039063|
|   TX|              4|              6|              6|    0.019875458|    0.019875458|    0.082049528|
+-----+---------------+---------------+---------------+---------------+---------------+---------------+


Answer 1:

You may need to do something similar to the following code:

import org.apache.spark.sql.functions.{col, collect_list}

df.groupBy(col("state"))
    .agg(collect_list(col("con_dist_1")).as("col1_quant"),
         collect_list(col("con_dist_2")).as("col2_quant"))
    .withColumn("col1_quant1", col("col1_quant")(0))
    .withColumn("col1_quant2", col("col1_quant")(1))
    .withColumn("col2_quant1", col("col2_quant")(0))
    .withColumn("col2_quant2", col("col2_quant")(1))
    .show

Output:
+-----+----------+--------------------+-----------+-----------+-----------+-----------+
|state|col1_quant|          col2_quant|col1_quant1|col1_quant2|col2_quant1|col2_quant2|
+-----+----------+--------------------+-----------+-----------+-----------+-----------+
|   AZ|    [4, 4]|[0.816039063, 0.8...|          4|          4|0.816039063|0.816039063|
|   CA|    [4, 6]|[0.082049528, 0.0...|          4|          6|0.082049528|0.019875458|
|   TX| [6, 4, 6]|[0.019875458, 0.0...|          6|          4|0.019875458|0.082049528|
+-----+----------+--------------------+-----------+-----------+-----------+-----------+

Probably the last set of withColumn calls should be inside a loop driven by the number of records per state, as sketched below.
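A rough sketch of that idea (assuming grouped is the result of the groupBy/agg above, and maxQuantiles is an assumed upper bound on the list length; indexing past the end of an array column simply yields null):

val grouped = df.groupBy(col("state"))
  .agg(collect_list(col("con_dist_1")).as("col1_quant"),
       collect_list(col("con_dist_2")).as("col2_quant"))

val maxQuantiles = 3  // assumed upper bound on records per state
val result = (1 to maxQuantiles).foldLeft(grouped) { (acc, i) =>
  acc.withColumn(s"col1_quant$i", col("col1_quant")(i - 1))
     .withColumn(s"col2_quant$i", col("col2_quant")(i - 1))
}
result.show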

Hope this helps!


Answer 2:

UPDATE

I found the percentile_approx function from the Hive context, so you do not need to use the stat functions.

import org.apache.spark.sql.functions.{col, expr}
import spark.implicits._  // for the $"state" syntax

val states = Seq("CA", "AZ")
val cols = Seq("con_dist_1", "con_dist_2")

// one percentile_approx expression per column, aliased as <col>_quantiles
val l = cols.map(c => expr(s"percentile_approx($c, Array(0.25, 0.5, 0.75)) as ${c}_quantiles"))
val df2 = df.filter($"state".isin(states: _*)).groupBy("state").agg(l.head, l.tail: _*)

df2.select(col("state") +: cols.flatMap(c => (1 until 4).map(i =>
  col(c + "_quantiles")(i - 1).alias(c + "_quantile_" + i))): _*).show(false)

Here I tried to automate the process for the given states and cols. The result will be:

+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+
|state|con_dist_1_quantile_1|con_dist_1_quantile_2|con_dist_1_quantile_3|con_dist_2_quantile_1|con_dist_2_quantile_2|con_dist_2_quantile_3|
+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+
|AZ   |4                    |4                    |4                    |0.816039063          |0.816039063          |0.816039063          |
|CA   |4                    |4                    |6                    |0.019875458          |0.082049528          |0.816039063          |
+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+

Note that the result is slightly different from your expected output because I used the states = Seq("CA", "AZ") that you provided, so TX is not included.
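If you also want TX in the output, as in your expected result, you can simply add it to the list (or drop the filter step entirely to keep every state):

val states = Seq("CA", "AZ", "TX")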


ORIGINAL

Use a Window partitioned by state and calculate percent_rank for each column:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.percent_rank

val w1 = Window.partitionBy("state").orderBy("con_dist_1")
val w2 = Window.partitionBy("state").orderBy("con_dist_2")
df.withColumn("p1", percent_rank.over(w1))
  .withColumn("p2", percent_rank.over(w2))
  .show(false)

You can filter the dataframe first if you only want specific states. In any case, the result is:

+---+---------+-----------+----------+-----------+-----+---+---+
|id |date     |revenue    |con_dist_1|con_dist_2 |state|p1 |p2 |
+---+---------+-----------+----------+-----------+-----+---+---+
|10 |1/15/2018|0.006628853|4         |0.816039063|AZ   |0.0|0.0|
|10 |1/15/2018|0.006628853|4         |0.816039063|AZ   |0.0|0.0|
|10 |1/15/2018|0.010680705|6         |0.019875458|CA   |1.0|0.0|
|10 |1/15/2018|0.01378215 |4         |0.082049528|CA   |0.0|0.5|
|10 |1/15/2018|0.006628853|4         |0.816039063|CA   |0.0|1.0|
|10 |1/15/2018|0.010680705|6         |0.019875458|TX   |0.5|0.0|
|10 |1/15/2018|0.010680705|6         |0.019875458|TX   |0.5|0.0|
|10 |1/15/2018|0.01378215 |4         |0.082049528|TX   |0.0|1.0|
+---+---------+-----------+----------+-----------+-----+---+---+
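As noted above, you could restrict this to the states of interest before applying the windows; a minimal sketch, reusing the states sequence and the w1/w2 windows defined earlier:

df.filter($"state".isin(states: _*))
  .withColumn("p1", percent_rank.over(w1))
  .withColumn("p2", percent_rank.over(w2))
  .show(false)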

Comments:

I have to head home now; I'll take a look once I get there. :)
