如何选择每组的第一行?
Posted
技术标签:
【中文标题】如何选择每组的第一行?【英文标题】:How to select the first row of each group? 【发布时间】:2015-11-23 18:49:25 【问题描述】:我有一个 DataFrame 生成如下:
df.groupBy($"Hour", $"Category")
.agg(sum($"value") as "TotalValue")
.sort($"Hour".asc, $"TotalValue".desc))
结果如下:
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
| 0| cat26| 30.9|
| 0| cat13| 22.1|
| 0| cat95| 19.6|
| 0| cat105| 1.3|
| 1| cat67| 28.5|
| 1| cat4| 26.8|
| 1| cat13| 12.6|
| 1| cat23| 5.3|
| 2| cat56| 39.6|
| 2| cat40| 29.7|
| 2| cat187| 27.9|
| 2| cat68| 9.8|
| 3| cat8| 35.6|
| ...| ....| ....|
+----+--------+----------+
如您所见,DataFrame 按Hour
升序排列,然后按TotalValue
降序排列。
我想选择每个组的第一行,即
从 Hour==0 组中选择 (0,cat26,30.9) 从 Hour==1 组中选择 (1,cat67,28.5) 从 Hour==2 组中选择 (2,cat56,39.6) 等等所以期望的输出是:
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
| 0| cat26| 30.9|
| 1| cat67| 28.5|
| 2| cat56| 39.6|
| 3| cat8| 35.6|
| ...| ...| ...|
+----+--------+----------+
如果还能够选择每个组的前 N 行可能会很方便。
非常感谢任何帮助。
【问题讨论】:
【参考方案1】:窗口函数:
这样的事情应该可以解决问题:
import org.apache.spark.sql.functions.row_number, max, broadcast
import org.apache.spark.sql.expressions.Window
val df = sc.parallelize(Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")
val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)
val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
在数据严重倾斜的情况下,此方法效率低下。此问题由 SPARK-34775 跟踪,将来可能会得到解决 (SPARK-37099)。
纯 SQL 聚合,后跟 join
:
您也可以加入聚合数据框:
val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))
val dfTopByJoin = df.join(broadcast(dfMax),
($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
.drop("max_hour")
.drop("max_value")
dfTopByJoin.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
它将保留重复值(如果每小时有多个类别具有相同的总值)。您可以按如下方式删除它们:
dfTopByJoin
.groupBy($"hour")
.agg(
first("category").alias("category"),
first("TotalValue").alias("TotalValue"))
对structs
使用排序:
虽然没有经过很好的测试,但不需要连接或窗口函数的技巧:
val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
.groupBy($"hour")
.agg(max("vs").alias("vs"))
.select($"Hour", $"vs.Category", $"vs.TotalValue")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
使用 DataSet API(Spark 1.6+、2.0+):
Spark 1.6:
case class Record(Hour: Integer, Category: String, TotalValue: Double)
df.as[Record]
.groupBy($"hour")
.reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
.show
// +---+--------------+
// | _1| _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+
Spark 2.0 或更高版本:
df.as[Record]
.groupByKey(_.Hour)
.reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
最后两种方法可以利用 map 端合并,并且不需要完全 shuffle,因此与窗口函数和连接相比,大多数时间应该表现出更好的性能。这些手杖也可以在completed
输出模式下与结构化流一起使用。
不要使用:
df.orderBy(...).groupBy(...).agg(first(...), ...)
它可能看起来有效(尤其是在local
模式下),但它并不可靠(请参阅SPARK-16207,linking relevant JIRA issue 和SPARK-30335 归功于Tzach Zohar)。
同样的注释适用于
df.orderBy(...).dropDuplicates(...)
内部使用等效的执行计划。
【讨论】:
从 spark 1.6 开始看起来是 row_number() 而不是 rowNumber 关于不要使用 df.orderBy(...).gropBy(...)。什么情况下可以依赖orderBy(...)?或者如果我们不能确定 orderBy() 是否会给出正确的结果,我们有什么替代方案? 我可能忽略了一些东西,但总的来说建议avoid groupByKey,而不是使用reduceByKey。此外,您将节省一行。 @Thomas 避免 groupBy/groupByKey 只是在处理 RDD 时,您会注意到 Dataset api 甚至没有 reduceByKey 函数。 @Thomas DataFrame / Dataset groupBy behaviour/optimization【参考方案2】:对于按多列分组的 Spark 2.0.2:
import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window
val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)
val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
【讨论】:
这段代码或多或少包含在Apache DataFu'sdedupWithOrder method中【参考方案3】:这与zero323 的answer 完全相同,但采用SQL 查询方式。
假设数据框已创建并注册为
df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0 |cat26 |30.9 |
//|0 |cat13 |22.1 |
//|0 |cat95 |19.6 |
//|0 |cat105 |1.3 |
//|1 |cat67 |28.5 |
//|1 |cat4 |26.8 |
//|1 |cat13 |12.6 |
//|1 |cat23 |5.3 |
//|2 |cat56 |39.6 |
//|2 |cat40 |29.7 |
//|2 |cat187 |27.9 |
//|2 |cat68 |9.8 |
//|3 |cat8 |35.6 |
//+----+--------+----------+
窗口功能:
sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
简单的 SQL 聚合,然后是连接:
sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
"(select Hour, Category, TotalValue from table tmp1 " +
"join " +
"(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
"on " +
"tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
"group by tmp3.Hour")
.show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
对结构使用排序:
sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
DataSets way 和 don't dos 与原始答案中的相同
【讨论】:
【参考方案4】:您可以在 Spark 3.0 中使用max_by()
函数!
https://spark.apache.org/docs/3.0.0-preview/api/sql/index.html#max_by
val df = sc.parallelize(Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")
// Register the DataFrame as a SQL temporary view
df.createOrReplaceTempView("table")
// Using SQL
val result = spark.sql("select Hour, max_by(Category, TotalValue) AS Category, max(TotalValue) as TotalValue FROM table group by Hour order by Hour")
// or Using DataFrame API
val result = df.groupBy("Hour").
agg(expr("max_by(Category, TotalValue)").as("Category"), max("TotalValue").as("TotalValue")).
sort("Hour")
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
| 0| cat26| 30.9|
| 1| cat67| 28.5|
| 2| cat56| 39.6|
| 3| cat8| 35.6|
+----+--------+----------+
【讨论】:
【参考方案5】:模式是 按键分组 => 对每个组做一些事情,例如reduce => 返回数据框
我认为在这种情况下,Dataframe 抽象有点麻烦,所以我使用了 RDD 功能
val rdd: RDD[Row] = originalDf
.rdd
.groupBy(row => row.getAs[String]("grouping_row"))
.map(iterableTuple =>
iterableTuple._2.reduce(reduceFunction)
)
val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)
【讨论】:
【参考方案6】:您可以使用Apache DataFu 轻松做到这一点(实现类似于Antonin's answer)。
import datafu.spark.DataFrameOps._
val df = sc.parallelize(Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")
df.dedupWithOrder($"Hour", $"TotalValue".desc).show
这将导致
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
| 0| cat26| 30.9|
| 3| cat8| 35.6|
| 1| cat67| 28.5|
| 2| cat56| 39.6|
+----+--------+----------+
(是的,结果不会按小时排序,但如果它很重要,你可以稍后再做)
还有一个 API - dedupTopN - 用于获取前 N 行。还有另一个 API - dedupWithCombiner - 当您期望每个分组有大量行时。
(完全披露 - 我是 DataFu 项目的一部分)
【讨论】:
【参考方案7】:下面的解决方案只执行一个 groupBy 并一次提取包含 maxValue 的数据帧行。无需进一步的联接或 Windows。
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame
//df is the dataframe with Day, Category, TotalValue
implicit val dfEnc = RowEncoder(df.schema)
val res: DataFrame = df.groupByKey(r) => r.getInt(0).mapGroups[Row](day: Int, rows: Iterator[Row]) => i.maxBy(r) => r.getDouble(2)
【讨论】:
但它首先会洗牌。它几乎没有改进(可能不比窗口函数差,具体取决于数据)。 你有一个小组第一名,这将触发随机播放。它并不比窗口函数差,因为在窗口函数中,它将评估数据框中每一行的窗口。【参考方案8】:使用 dataframe api 执行此操作的一个好方法是使用 argmax 逻辑,如下所示
val df = Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")
df.groupBy($"Hour")
.agg(max(struct($"TotalValue", $"Category")).as("argmax"))
.select($"Hour", $"argmax.*").show
+----+----------+--------+
|Hour|TotalValue|Category|
+----+----------+--------+
| 1| 28.5| cat67|
| 3| 35.6| cat8|
| 2| 39.6| cat56|
| 0| 30.9| cat26|
+----+----------+--------+
【讨论】:
【参考方案9】:在这里你可以这样做 -
val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")
data.withColumnRenamed("_1","Hour").show
【讨论】:
以上是关于如何选择每组的第一行?的主要内容,如果未能解决你的问题,请参考以下文章