在pyspark中检索每组DataFrame中的前n个

Posted

技术标签:

【中文标题】在pyspark中检索每组DataFrame中的前n个【英文标题】:Retrieve top n in each group of a DataFrame in pyspark 【发布时间】:2016-11-18 18:19:27 【问题描述】:

pyspark中有一个DataFrame,数据如下:

user_id object_id score
user_1  object_1  3
user_1  object_1  1
user_1  object_2  2
user_2  object_1  5
user_2  object_2  2
user_2  object_2  6

我期望的是在每个组中返回 2 条具有相同 user_id 的记录,这些记录需要获得最高分。因此,结果应如下所示:

user_id object_id score
user_1  object_1  3
user_1  object_2  2
user_2  object_2  6
user_2  object_1  5

我真的是 pyspark 的新手,谁能给我一个代码 sn-p 或门户到这个问题的相关文档?非常感谢!

【问题讨论】:

【参考方案1】:

我认为您需要使用window functions 来获得基于user_idscore 的每一行的排名,然后过滤您的结果以仅保留前两个值。

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) 
  .filter(col('rank') <= 2) 
  .show() 
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+

总的来说,官方programming guide是开始学习Spark的好地方。

数据

rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])

【讨论】:

我认为有些地方需要调整。 object_id 对groupbytop 过程都没有影响。而我想要的是group byuser_id,并在每个组中分别检索得分最高的前两个记录,而不仅仅是第一条记录。非常感谢! 可以在过滤器中使用窗口函数:df.filter(rank().over(window) &lt;= 2) 我大吃一惊...我确信我之前在过滤器中使用了窗口函数。但我确实无法重现它(在 2 和 1.6 中都没有)。我确实以一种异国情调的方式使用它,但我不记得何时或如何使用它。对不起。 您可能需要考虑使用row_number 而不是rank,以防获得相同的排名并且您仍然想要前n @TomerBenDavid 此评论值得更多支持,谢谢先生。【参考方案2】:

如果在获得排名相等时使用row_number 而不是rank,则Top-n 更准确:

val n = 5
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20) \
  .toPandas()

注意 limit(20).toPandas() 技巧而不是 show() 用于 Jupyter 笔记本以获得更好的格式。

【讨论】:

记得添加 from pyspark.sql.functions import row_number 以使其正常工作 什么会更有效(快速)计算?我怀疑是差不多的。这会是一种更有效的方法吗?我正在处理一个 110 GB 的数据集,其中包含 470 万个类别(到 groupBy),每个类别大约有 4,300 行,并且它永远在一个大型集群上。 这里是描述difference between rank, row_number, and dense_rank的最佳链接【参考方案3】:

我知道这个问题是针对pyspark 提出的,我在Scala 中寻找类似的答案,即

在 Scala 中检索每组 DataFrame 中的前 n 个值

这是@mtoto 答案的scala 版本。

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank
import org.apache.spark.sql.functions.col

val window = Window.partitionBy("user_id").orderBy('score desc)
val rankByScore = rank().over(window)
df1.select('*, rankByScore as 'rank).filter(col("rank") <= 2).show() 
# you can change the value 2 to any number you want. Here 2 represents the top 2 values

更多示例可以在here 找到。

【讨论】:

【参考方案4】:

这是另一种没有窗口函数的解决方案,用于从 pySpark DataFrame 中获取前 N 条记录。

# Import Libraries
from pyspark.sql.functions import col

# Sample Data
rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])

# Get top n records as Row Objects
row_list = df.orderBy(col("score").desc()).head(5)

# Convert row objects to DF
sorted_df = spark.createDataFrame(row_list)

# Display DataFrame
sorted_df.show()

输出

+-------+---------+-----+
|user_id|object_id|score|
+-------+---------+-----+
| user_1| object_2|    2|
| user_2| object_2|    2|
| user_1| object_1|    3|
| user_2| object_1|    5|
| user_2| object_2|    6|
+-------+---------+-----+

如果您对 Spark 中更多的窗口函数感兴趣,可以参考我的一篇博客:https://medium.com/expedia-group-tech/deep-dive-into-apache-spark-window-functions-7b4e39ad3c86

【讨论】:

【参考方案5】:

使用 Python 3 和 Spark 2.4

from pyspark.sql import Window
import pyspark.sql.functions as f

def get_topN(df, group_by_columns, order_by_column, n=1):
    window_group_by_columns = Window.partitionBy(group_by_columns)
    ordered_df = df.select(df.columns + [
        f.row_number().over(window_group_by_columns.orderBy(order_by_column.desc())).alias('row_rank')])
    topN_df = ordered_df.filter(f"row_rank <= n").drop("row_rank")
    return topN_df

top_n_df = get_topN(your_dataframe, [group_by_columns],[order_by_columns], 1) 

【讨论】:

【参考方案6】:

使用ROW_NUMBER() 函数在 PYSPARK SQLquery 中查找第 N 个最大值:

SELECT * FROM (
    SELECT e.*, 
    ROW_NUMBER() OVER (ORDER BY col_name DESC) rn 
    FROM Employee e
)
WHERE rn = N

N 是列中要求的第 n 个最大值

输出:

[Stage 2:>               (0 + 1) / 1]++++++++++++++++
+-----------+
|col_name   |
+-----------+
|1183395    |
+-----------+

查询将返回N个最大值

【讨论】:

以上是关于在pyspark中检索每组DataFrame中的前n个的主要内容,如果未能解决你的问题,请参考以下文章

Pandas:选择每组的前几行

跳过每组中的前 n 行

在pyspark中填充每组的缺失值?

如何在 PySpark 中为一个组迭代 Dataframe / RDD 的每一行。?

获取每组的前 n 个结果 [重复]

执行 pyspark.sql.DataFrame.take(4) 超过一小时