collect_list 通过保留基于另一个变量的顺序

Posted

技术标签:

【中文标题】collect_list 通过保留基于另一个变量的顺序【英文标题】:collect_list by preserving order based on another variable 【发布时间】:2017-10-05 07:34:03 【问题描述】:

我正在尝试使用对现有列集的 groupby 聚合在 Pyspark 中创建一个新的列表列。下面提供了一个示例输入数据框:

------------------------
id | date        | value
------------------------
1  |2014-01-03   | 10 
1  |2014-01-04   | 5
1  |2014-01-05   | 15
1  |2014-01-06   | 20
2  |2014-02-10   | 100   
2  |2014-03-11   | 500
2  |2014-04-15   | 1500

预期的输出是:

id | value_list
------------------------
1  | [10, 5, 15, 20]
2  | [100, 500, 1500]

列表中的值按日期排序。

我尝试使用 collect_list 如下:

from pyspark.sql import functions as F
ordered_df = input_df.orderBy(['id','date'],ascending = True)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))

但是即使我在聚合之前按日期对输入数据帧进行排序,collect_list 也不能​​保证顺序。

有人可以通过保留基于第二个(日期)变量的顺序来帮助进行聚合吗?

【问题讨论】:

【参考方案1】:
from pyspark.sql import functions as F
from pyspark.sql import Window

w = Window.partitionBy('id').orderBy('date')

sorted_list_df = input_df.withColumn(
            'sorted_list', F.collect_list('value').over(w)
        )\
        .groupBy('id')\
        .agg(F.max('sorted_list').alias('sorted_list'))

Window 用户提供的示例通常并不能真正解释发生了什么,所以让我为您剖析一下。

如您所知,将collect_listgroupBy 一起使用将产生一个无序 值列表。这是因为根据您的数据分区方式,Spark 会在找到组中的一行后立即将值附加到您的列表中。然后,顺序取决于 Spark 如何计划您对执行程序的聚合。

Window 函数允许您控制这种情况,按特定值对行进行分组,以便您可以对每个结果组执行操作 over

w = Window.partitionBy('id').orderBy('date')
partitionBy - 您想要具有相同 id 的行组/分区 orderBy - 您希望组中的每一行都按 date 排序

一旦定义了 Window 的范围——“具有相同 id 的行,按 date 排序”——就可以使用它对其执行操作,在本例中为 collect_list

F.collect_list('value').over(w)

此时,您创建了一个新列 sorted_list,其中包含按日期排序的有序值列表,但每个 id 仍然有重复的行。要删除您想要 groupBy id 的重复行并为每个组保留 max 值:

.groupBy('id')\
.agg(F.max('sorted_list').alias('sorted_list'))

【讨论】:

由于使用了 Spark 基本功能,这应该是公认的答案 - 非常好! 需要最大值,因为对于相同的“id”,会为每一行创建一个列表,排序顺序为:[10],然后是[10, 5],然后是[10, 5 , 15],然后 [10, 5, 15, 20] 对于 id=1。取列表的最大值取最长的列表(此处为 [10, 5, 15, 20])。 这对内存有什么影响?当我们处理十亿多个事件的链接时,当一个链可以在收集的列表中包含多达 10.000 个项目时,这种方法是否比公认的答案更好? 这不是很膨胀吗?如果我有 1000 万组,每组有 24 个元素。 F.collect_list('value').over(w) 将创建一个从 1 到 24 的新列大小,即 1000 万 * 24 次。然后通过从每个组中获取最大的行来做另一个组。 如果您使用的是collect_set 而不是collect_list,这将不起作用。【参考方案2】:

如果您将日期和值都收集为一个列表,您可以使用和udf 根据日期对结果列进行排序,然后只保留结果中的值。

import operator
import pyspark.sql.functions as F

# create list column
grouped_df = input_df.groupby("id") \
               .agg(F.collect_list(F.struct("date", "value")) \
               .alias("list_col"))

# define udf
def sorter(l):
  res = sorted(l, key=operator.itemgetter(0))
  return [item[1] for item in res]

sort_udf = F.udf(sorter)

# test
grouped_df.select("id", sort_udf("list_col") \
  .alias("sorted_list")) \
  .show(truncate = False)
+---+----------------+
|id |sorted_list     |
+---+----------------+
|1  |[10, 5, 15, 20] |
|2  |[100, 500, 1500]|
+---+----------------+

【讨论】:

感谢详细的示例...我只是在几百万的更大数据上尝试过,我得到的序列与 collect_list 完全相同...有没有办法解释原因这可能会发生吗?另外,检查 collect_list 似乎只在一个日期内用多个值弄乱那些情况......这是否意味着 collect_list 也保持顺序? 在您的代码中,您在 collect_list() 之前对整个数据集进行排序,所以是的。但这不是必需的,在将日期和值都收集到列表中之后,对生成的元组列表进行排序会更有效。 只是为了澄清...对列进行排序并在已排序的列上使用 collect_list 会保留顺序吗? 分布式系统中的顺序通常是没有意义的,因此除非每个id的值都在一个分区中,否则无法保证正确的顺序。 这个答案现在已经相当老了,我认为引入 array_sort 正如其他答案所描述的那样,这是最好的方法,因为它不需要 UDF 的开销。跨度> 【参考方案3】:

您可以使用sort_array 函数。如果您将日期和值都收集为一个列表,您可以使用 sort_array 对结果列进行排序,并只保留您需要的列。

import operator
import pyspark.sql.functions as F

grouped_df = input_df.groupby("id") \
               .agg(F.sort_array(F.collect_list(F.struct("date", "value"))) \
.alias("collected_list")) \
.withColumn("sorted_list",col("collected_list.value")) \
.drop("collected_list")
.show(truncate=False)

+---+----------------+
|id |sorted_list     |
+---+----------------+
|1  |[10, 5, 15, 20] |
|2  |[100, 500, 1500]|
+---+----------------+ ```````

【讨论】:

非常感谢。我找到了 Window.patitionBy 然后获取最大行不能对大数据执行。您的解决方案要快约 200 倍。 是的,这在 scala 中也更快:grouped_df = input_df.groupBy("id").agg(sort_array(collect_list(struct("date", "value"))).alias ("collected_list")).withColumn("sorted_list", col("collected_list.value")) .drop("collected_list") 我不知道 Spark 将 collect_list.value 这个概念理解为对应字段值的数组。不错!【参考方案4】:

这个问题是针对 PySpark 的,但对于 Scala Spark 也有帮助。

让我们准备测试数据框:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame, Row, SparkSession
import org.apache.spark.sql.expressions. Window, UserDefinedFunction

import java.sql.Date
import java.time.LocalDate

val spark: SparkSession = ...

// Out test data set
val data: Seq[(Int, Date, Int)] = Seq(
  (1, Date.valueOf(LocalDate.parse("2014-01-03")), 10),
  (1, Date.valueOf(LocalDate.parse("2014-01-04")), 5),
  (1, Date.valueOf(LocalDate.parse("2014-01-05")), 15),
  (1, Date.valueOf(LocalDate.parse("2014-01-06")), 20),
  (2, Date.valueOf(LocalDate.parse("2014-02-10")), 100),
  (2, Date.valueOf(LocalDate.parse("2014-02-11")), 500),
  (2, Date.valueOf(LocalDate.parse("2014-02-15")), 1500)
)

// Create dataframe
val df: DataFrame = spark.createDataFrame(data)
  .toDF("id", "date", "value")
df.show()
//+---+----------+-----+
//| id|      date|value|
//+---+----------+-----+
//|  1|2014-01-03|   10|
//|  1|2014-01-04|    5|
//|  1|2014-01-05|   15|
//|  1|2014-01-06|   20|
//|  2|2014-02-10|  100|
//|  2|2014-02-11|  500|
//|  2|2014-02-15| 1500|
//+---+----------+-----+

使用UDF

// Group by id and aggregate date and value to new column date_value
val grouped = df.groupBy(col("id"))
  .agg(collect_list(struct("date", "value")) as "date_value")
grouped.show()
grouped.printSchema()
// +---+--------------------+
// | id|          date_value|
// +---+--------------------+
// |  1|[[2014-01-03,10],...|
// |  2|[[2014-02-10,100]...|
// +---+--------------------+

// udf to extract data from Row, sort by needed column (date) and return value
val sortUdf: UserDefinedFunction = udf((rows: Seq[Row]) => 
  rows.map  case Row(date: Date, value: Int) => (date, value) 
    .sortBy  case (date, value) => date 
    .map  case (date, value) => value 
)

// Select id and value_list
val r1 = grouped.select(col("id"), sortUdf(col("date_value")).alias("value_list"))
r1.show()
// +---+----------------+
// | id|      value_list|
// +---+----------------+
// |  1| [10, 5, 15, 20]|
// |  2|[100, 500, 1500]|
// +---+----------------+

使用窗口

val window = Window.partitionBy(col("id")).orderBy(col("date"))
val sortedDf = df.withColumn("values_sorted_by_date", collect_list("value").over(window))
sortedDf.show()
//+---+----------+-----+---------------------+
//| id|      date|value|values_sorted_by_date|
//+---+----------+-----+---------------------+
//|  1|2014-01-03|   10|                 [10]|
//|  1|2014-01-04|    5|              [10, 5]|
//|  1|2014-01-05|   15|          [10, 5, 15]|
//|  1|2014-01-06|   20|      [10, 5, 15, 20]|
//|  2|2014-02-10|  100|                [100]|
//|  2|2014-02-11|  500|           [100, 500]|
//|  2|2014-02-15| 1500|     [100, 500, 1500]|
//+---+----------+-----+---------------------+

val r2 = sortedDf.groupBy(col("id"))
  .agg(max("values_sorted_by_date").as("value_list")) 
r2.show()
//+---+----------------+
//| id|      value_list|
//+---+----------------+
//|  1| [10, 5, 15, 20]|
//|  2|[100, 500, 1500]|
//+---+----------------+

【讨论】:

是否可以在没有窗口或 udf 的情况下通过组合 explode、group by、order by 来完成此操作?【参考方案5】:

为了确保对每个 id 进行排序,我们可以使用 sortWithinPartitions:

from pyspark.sql import functions as F
ordered_df = (
    input_df
        .repartition(input_df.id)
        .sortWithinPartitions(['date'])


)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))

【讨论】:

分组是在排序之后发生的。排序顺序会分步保留在组中吗? AFAIK 没有这样的保证【参考方案6】:

我尝试了 TMichel 方法,但对我不起作用。当我执行最大聚合时,我没有取回列表的最高值。所以对我有用的是:

def max_n_values(df, key, col_name, number):
    '''
    Returns the max n values of a spark dataframe
    partitioned by the key and ranked by the col_name
    '''
    w2 = Window.partitionBy(key).orderBy(f.col(col_name).desc())
    output = df.select('*',
                       f.row_number().over(w2).alias('rank')).filter(
                           f.col('rank') <= number).drop('rank')
    return output

def col_list(df, key, col_to_collect, name, score):
    w = Window.partitionBy(key).orderBy(f.col(score).desc())

    list_df = df.withColumn(name, f.collect_set(col_to_collect).over(w))
    size_df = list_df.withColumn('size', f.size(name))
    output = max_n_values(df=size_df,
                               key=key,
                               col_name='size',
                               number=1)
    return output

【讨论】:

我认为解释一下这对您的工作原理以及与已接受答案的区别可能会很有用 当我尝试 Tmichel 的方法时,最大值不起作用。我没有取回包含最多元素的列表,而是取回了随机列表。所以我所做的是我创建了一个新列来测量大小并获得 eahc 分区的最高值。希望这是有道理的!【参考方案7】:

从 Spark 2.4 开始,@mtoto 的答案中创建的 collect_list(ArrayType) 可以使用 SparkSQL 的内置函数 transform 和 array_sort 进行后处理(不需要 udf):

from pyspark.sql.functions import collect_list, expr, struct

df.groupby('id') \
  .agg(collect_list(struct('date','value')).alias('value_list')) \
  .withColumn('value_list', expr('transform(array_sort(value_list), x -> x.value)')) \
  .show()
+---+----------------+
| id|      value_list|
+---+----------------+
|  1| [10, 5, 15, 20]|
|  2|[100, 500, 1500]|
+---+----------------+ 

注意:如果需要降序,请将array_sort(value_list) 更改为sort_array(value_list, False)

警告:如果项目(在 collect_list 中)必须按多个字段(列)以混合顺序(即orderBy('col1', desc('col2')))排序,则 array_sort() 和 sort_array() 将不起作用。

【讨论】:

【参考方案8】:

在 Spark SQL 世界中,这个问题的答案是:

SELECT 
browser, max(list)
from (
  SELECT
    id,
    COLLECT_LIST(value) OVER (PARTITION BY id ORDER BY date DESC) as list
  FROM browser_count
  GROUP BYid, value, date) 
Group by browser;

【讨论】:

【参考方案9】:

如果你想在这里使用 spark sql,你可以如何实现这一点。假设表名(或临时视图)为temp_table

select
t1.id,
collect_list(value) as value_list
(Select * from temp_table order by id,date) t1
group by 1

【讨论】:

【参考方案10】:

作为ShadyStego 所说的补充,我一直在Spark 上测试sortWithinPartitions 和GroupBy 的用法,发现它的性能比Window 函数或UDF 好得多。尽管如此,使用此方法时,每个分区都会出现一次错误排序的问题,但可以轻松解决。我在这里展示它Spark (pySpark) groupBy misordering first element on collect_list。

此方法在大型 DataFrame 上特别有用,但如果驱动程序内存不足,则可能需要大量分区。

【讨论】:

以上是关于collect_list 通过保留基于另一个变量的顺序的主要内容,如果未能解决你的问题,请参考以下文章

通过 Spark SQL 实现 `collect_list`

基于spark中的列组合数据

基于另一个数组保留二维数组的每一行中的元素

Scala变量

是否可以更改一个函数(如计数器)中的全局变量,然后保留该值以供另一个函数使用? [重复]

如何通过phpunit上的所有测试保留会话?