在pyspark中检索每组DataFrame中的前n个
Posted
技术标签:
【中文标题】在pyspark中检索每组DataFrame中的前n个【英文标题】:Retrieve top n in each group of a DataFrame in pyspark 【发布时间】:2016-11-18 18:19:27 【问题描述】:pyspark中有一个DataFrame,数据如下:
user_id object_id score
user_1 object_1 3
user_1 object_1 1
user_1 object_2 2
user_2 object_1 5
user_2 object_2 2
user_2 object_2 6
我期望的是在每个组中返回 2 条具有相同 user_id 的记录,这些记录需要获得最高分。因此,结果应如下所示:
user_id object_id score
user_1 object_1 3
user_1 object_2 2
user_2 object_2 6
user_2 object_1 5
我真的是 pyspark 的新手,谁能给我一个代码 sn-p 或门户到这个问题的相关文档?非常感谢!
【问题讨论】:
【参考方案1】:我认为您需要使用window functions 来获得基于user_id
和score
的每一行的排名,然后过滤您的结果以仅保留前两个值。
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col
window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())
df.select('*', rank().over(window).alias('rank'))
.filter(col('rank') <= 2)
.show()
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1| 3| 1|
#| user_1| object_2| 2| 2|
#| user_2| object_2| 6| 1|
#| user_2| object_1| 5| 2|
#+-------+---------+-----+----+
总的来说,官方programming guide是开始学习Spark的好地方。
数据
rdd = sc.parallelize([("user_1", "object_1", 3),
("user_1", "object_2", 2),
("user_2", "object_1", 5),
("user_2", "object_2", 2),
("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
【讨论】:
我认为有些地方需要调整。 object_id 对groupby
或top
过程都没有影响。而我想要的是group by
user_id,并在每个组中分别检索得分最高的前两个记录,而不仅仅是第一条记录。非常感谢!
可以在过滤器中使用窗口函数:df.filter(rank().over(window) <= 2)
我大吃一惊...我确信我之前在过滤器中使用了窗口函数。但我确实无法重现它(在 2 和 1.6 中都没有)。我确实以一种异国情调的方式使用它,但我不记得何时或如何使用它。对不起。
您可能需要考虑使用row_number
而不是rank
,以防获得相同的排名并且您仍然想要前n
@TomerBenDavid 此评论值得更多支持,谢谢先生。【参考方案2】:
如果在获得排名相等时使用row_number
而不是rank
,则Top-n 更准确:
val n = 5
df.select(col('*'), row_number().over(window).alias('row_number')) \
.where(col('row_number') <= n) \
.limit(20) \
.toPandas()
注意
limit(20).toPandas()
技巧而不是show()
用于 Jupyter 笔记本以获得更好的格式。
【讨论】:
记得添加from pyspark.sql.functions import row_number
以使其正常工作
什么会更有效(快速)计算?我怀疑是差不多的。这会是一种更有效的方法吗?我正在处理一个 110 GB 的数据集,其中包含 470 万个类别(到 groupBy),每个类别大约有 4,300 行,并且它永远在一个大型集群上。
这里是描述difference between rank, row_number, and dense_rank的最佳链接【参考方案3】:
我知道这个问题是针对pyspark
提出的,我在Scala
中寻找类似的答案,即
在 Scala 中检索每组 DataFrame 中的前 n 个值
这是@mtoto 答案的scala
版本。
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank
import org.apache.spark.sql.functions.col
val window = Window.partitionBy("user_id").orderBy('score desc)
val rankByScore = rank().over(window)
df1.select('*, rankByScore as 'rank).filter(col("rank") <= 2).show()
# you can change the value 2 to any number you want. Here 2 represents the top 2 values
更多示例可以在here 找到。
【讨论】:
【参考方案4】:这是另一种没有窗口函数的解决方案,用于从 pySpark DataFrame 中获取前 N 条记录。
# Import Libraries
from pyspark.sql.functions import col
# Sample Data
rdd = sc.parallelize([("user_1", "object_1", 3),
("user_1", "object_2", 2),
("user_2", "object_1", 5),
("user_2", "object_2", 2),
("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
# Get top n records as Row Objects
row_list = df.orderBy(col("score").desc()).head(5)
# Convert row objects to DF
sorted_df = spark.createDataFrame(row_list)
# Display DataFrame
sorted_df.show()
输出
+-------+---------+-----+
|user_id|object_id|score|
+-------+---------+-----+
| user_1| object_2| 2|
| user_2| object_2| 2|
| user_1| object_1| 3|
| user_2| object_1| 5|
| user_2| object_2| 6|
+-------+---------+-----+
如果您对 Spark 中更多的窗口函数感兴趣,可以参考我的一篇博客:https://medium.com/expedia-group-tech/deep-dive-into-apache-spark-window-functions-7b4e39ad3c86
【讨论】:
【参考方案5】:使用 Python 3 和 Spark 2.4
from pyspark.sql import Window
import pyspark.sql.functions as f
def get_topN(df, group_by_columns, order_by_column, n=1):
window_group_by_columns = Window.partitionBy(group_by_columns)
ordered_df = df.select(df.columns + [
f.row_number().over(window_group_by_columns.orderBy(order_by_column.desc())).alias('row_rank')])
topN_df = ordered_df.filter(f"row_rank <= n").drop("row_rank")
return topN_df
top_n_df = get_topN(your_dataframe, [group_by_columns],[order_by_columns], 1)
【讨论】:
【参考方案6】:使用ROW_NUMBER()
函数在 PYSPARK SQLquery 中查找第 N 个最大值:
SELECT * FROM (
SELECT e.*,
ROW_NUMBER() OVER (ORDER BY col_name DESC) rn
FROM Employee e
)
WHERE rn = N
N 是列中要求的第 n 个最大值
输出:
[Stage 2:> (0 + 1) / 1]++++++++++++++++
+-----------+
|col_name |
+-----------+
|1183395 |
+-----------+
查询将返回N个最大值
【讨论】:
以上是关于在pyspark中检索每组DataFrame中的前n个的主要内容,如果未能解决你的问题,请参考以下文章