您将如何限制 Spark 中每个分组键要处理的记录数? (对于倾斜的数据)

Posted

技术标签:

【中文标题】您将如何限制 Spark 中每个分组键要处理的记录数? (对于倾斜的数据)【英文标题】:How would you limit the number of records to process per grouped key in spark? (for skewed data) 【发布时间】:2019-07-24 22:09:36 【问题描述】:

我有两个大型数据集。有多个相同 ID 的分组。每组都有一个分数。我正在尝试将分数广播到每个组中的每个 id。但我有一个很好的约束,我不关心超过 1000 个 id 的组。

不幸的是,Spark 一直在阅读完整的分组。我似乎无法找到一种方法来降低限制,以便 Spark 最多只能读取 1000 条记录,如果还有更多则放弃。

到目前为止,我已经尝试过:

def run: Unit = 
    // ...

    val scores: RDD[(GroupId, Score)] = readScores(...)
    val data: RDD[(GroupId, Id)] = readData(...)

    val idToScore: RDD[(Id, Score)] = scores.cogroup(data)
      .flatMap(maxIdsPerGroupFilter(1000))

    // ...


def maxIdsPerGroupFilter(maxIds: Int)(t: (GroupId, (Iterable[Score], Iterable[Id]))): Iterator[(Id, Score)] = 
  t match 
    case (groupId: GroupId, (scores: Iterable[Score], ids: Iterable[Id])) =>
      if (!scores.iterator.hasNext) 
        return Iterator.empty
      
      val score: Score = scores.iterator.next()
      val iter = ids.iterator

      val uniqueIds: mutable.HashSet[Id] = new mutable.HashSet[Id]
      while (iter.hasNext) 
        uniqueIds.add(iter.next())
        if (uniqueIds.size > maxIds) 
          return Iterator.empty
        
      
      uniqueIds.map((_, score)).iterator
  

(即使过滤器函数只返回空迭代器的变体,Spark 仍然坚持读取所有数据)

这样做的副作用是,由于某些组的 id 太多,我的数据有很多偏差,并且在处理完整的数据时,工作永远无法完成。

我希望 reduce 端只读取它需要的数据,而不是因为数据倾斜而崩溃。

我有一种感觉,我需要以某种方式创建一个能够降低 limit 或 take 子句的转换,但我不知道如何。

【问题讨论】:

【参考方案1】:

我们不能在分组数据中使用 count() 过滤掉那些记录超过 1k 的组吗?

或者,如果您想让那些组也有超过 1k 条记录但只能选择 1k 条记录,那么在 spark sql 查询中您可以使用 ROW_NUMBER() OVER (PARTITION BY id ORDER BY someColumn DESC) AS rn 然后设置条件 rn

【讨论】:

以上是关于您将如何限制 Spark 中每个分组键要处理的记录数? (对于倾斜的数据)的主要内容,如果未能解决你的问题,请参考以下文章

获取每组分组结果的前 n 条记录

获取每组分组结果的前 n 条记录

MySQL:如何查询出每个分组中的 top n 条记录?

sql中根据表中一个字段分组分别统计每个分组的记录数

Spark:按组对记录进行排序?

mysql - 为每组分组的 SQL 结果获取具有最大值的记录[重复]