如何对 RDD 中的项目进行排名以建立连胜?

Posted

技术标签:

【中文标题】如何对 RDD 中的项目进行排名以建立连胜?【英文标题】:How can I rank items in a RDD to build a streak? 【发布时间】:2018-01-02 20:23:35 【问题描述】:

我有一个包含如下数据的 RDD:(downloadId: String, date: LocalDate, downloadCount: Int)。日期和下载 ID 是唯一的,下载计数是针对日期的。

我一直在尝试完成的是获取下载 ID 在所有下载 ID 中排名前 100 的连续天数(从当前日期倒数)。因此,如果给定的下载量在今天、昨天和前一天都在前 100 名中,那么它的连续性将是 3。

在 SQL 中,我想这可以使用窗口函数来解决。我见过类似的问题。 How to add a running count to rows in a 'streak' of consecutive days

(我对 Spark 比较陌生,但不确定如何映射减少 RDD 甚至开始解决这样的问题。)

更多信息,日期是最近 30 天,每天大约有 400 万个唯一下载 ID。

【问题讨论】:

【参考方案1】:

我建议您使用 DataFrame,因为它们比 RDD 更容易使用。 Leo 的答案较短,但我找不到它在哪里过滤前 100 名下载,所以我决定也发布我的答案。它不依赖于窗口函数,但它受限于您想要连续运行的过去天数。既然你说你只使用最近 30 天的数据,那应该没问题。

作为第一步,我编写了一些代码来生成类似于您所描述的 DF。您不需要运行第一个块(如果这样做,请减少行数,除非您有一个集群可以尝试它,它会占用大量内存)。您可以看到如何将 RDD (theData) 转换为 DF (baseData)。你应该像我一样为它定义一个模式。

import java.time.LocalDate
import scala.util.Random

val maxId = 10000
val numRows = 15000000
val lastDate = LocalDate.of(2017, 12, 31)

// Generates the data. As a convenience for working with Dataframes, I converted the dates to epoch days.

val theData = sc.parallelize(1.to(numRows).map
  _ => 
    val id = Random.nextInt(maxId)
    val nDownloads = Random.nextInt((id / 1000 + 1))
    Row(id, lastDate.minusDays(Random.nextInt(30)).toEpochDay, nDownloads)
  
)

//Working with Dataframes is much simples, so I'll generate a DF named baseData from the RDD

val schema = StructType(
    StructField("downloadId", IntegerType, false) ::
    StructField("date", LongType, false) ::
    StructField("downloadCount", IntegerType, false) :: Nil)

val baseData = sparkSession.sqlContext.createDataFrame(theData, schema)
  .groupBy($"downloadId", $"date")
    .agg(sum($"downloadCount").as("downloadCount"))
  .cache()

现在您在名为 baseData 的 DF 中获得了所需的数据。下一步是将其限制在每天的前 100 位 - 在进行任何额外的繁重转换之前,您应该丢弃不使用的数据。

import org.apache.spark.sql.types._
import org.apache.spark.sql.DataFrame, Row

def filterOnlyTopN(data: DataFrame, n: Int = 100): DataFrame = 
  // For each day in the data, let's find the cutoff # of downloads to make it into the top N
  val getTopNCutoff = udf((downloads: Seq[Long]) => 
    val reverseSortedDownloads = downloads.sortBy- _ 
    if (reverseSortedDownloads.length >= n)
      reverseSortedDownloads.drop(n - 1).head
    else
      reverseSortedDownloads.last
  )

  val topNLimitsByDate = data.groupBy($"date").agg(collect_set($"downloadCount").as("downloads"))
          .select($"date", getTopNCutoff($"downloads").as("cutoff"))

  // And then, let's throw away the records below the top 100
  data.join(topNLimitsByDate, Seq("date"))
    .filter($"downloadCount" >= $"cutoff")
    .drop("cutoff", "downloadCount")


val relevantData = filterOnlyTopN(baseData)

现在您已经有了只包含您需要的数据的relevantData DF,您可以计算它们的连续性。我将没有条纹的 id 保留为条纹 0,您可以使用 streaks.filter($"streak" > lit(0)) 过滤掉它们。

def getStreak(df: DataFrame, fromDate: Long): DataFrame = 
  val calcStreak = udf((dateList: Seq[Long]) => 
    if (!dateList.contains(fromDate))
      0
    else 
      val relevantDates = dateList.sortBy- _              // Order the dates descending
        .dropWhile(_ != fromDate)        // And drop everything until we find the starting day we are interested in
      if (relevantDates.length == 1)     // If there's only one day left, it's a one day streak
        1
      else                               // Otherwise, let's count the streak length (this works if no dates are left, too - but not with only 1 day)
        relevantDates.sliding(2)         // Take days by pairs
          .takeWhiletwoDays => twoDays(1) == twoDays(0) - 1   // While the pair is of consecutive days
          .length+1                      // And the streak will be the number of consecutive pairs + 1 (the initial day of the streak)
    
  )
  df.groupBy($"downloadId").agg(collect_list($"date").as("dates")).select($"downloadId", calcStreak($"dates").as("streak"))

val streaks = getStreak(relevantData, lastDate.toEpochDay)
streaks.show()

+------------+--------+
| downloadId | streak |
+------------+--------+
|       8086 |      0 |
|       9852 |      0 |
|       7253 |      0 |
|       9376 |      0 |
|       7833 |      0 |
|       9465 |      1 |
|       7880 |      0 |
|       9900 |      1 |
|       7993 |      0 |
|       9427 |      1 |
|       8389 |      1 |
|       8638 |      1 |
|       8592 |      1 |
|       6397 |      0 |
|       7754 |      1 |
|       7982 |      0 |
|       7554 |      0 |
|       6357 |      1 |
|       7340 |      0 |
|       6336 |      0 |
+------------+--------+

你有 streaks DF 和你需要的数据。

【讨论】:

【参考方案2】:

使用列出的 PostgreSQL 链接中的类似方法,您也可以在 Spark 中应用 Window 函数。 Spark 的 DataFrame API 没有 java.time.LocalDate 的编码器,因此您需要将其转换为 java.sql.Date

步骤如下:首先,将 RDD 转换为支持日期格式的 DataFrame;接下来,创建一个UDF 来计算baseDate,它需要一个日期和一个按id 时间顺序的行号(使用Window 函数生成)作为参数。另一个 Window 函数用于计算 per-id-baseDate 行数,这是想要的条纹值:

import java.time.LocalDate

val rdd = sc.parallelize(Seq(
  (1, LocalDate.parse("2017-12-13"), 2),
  (1, LocalDate.parse("2017-12-16"), 1),
  (1, LocalDate.parse("2017-12-17"), 1),
  (1, LocalDate.parse("2017-12-18"), 2),
  (1, LocalDate.parse("2017-12-20"), 1),
  (1, LocalDate.parse("2017-12-21"), 3),
  (2, LocalDate.parse("2017-12-15"), 2),
  (2, LocalDate.parse("2017-12-16"), 1),
  (2, LocalDate.parse("2017-12-19"), 1),
  (2, LocalDate.parse("2017-12-20"), 1),
  (2, LocalDate.parse("2017-12-21"), 2),
  (2, LocalDate.parse("2017-12-23"), 1)
))

val df = rdd.map case (id, date, count) => (id, java.sql.Date.valueOf(date), count) .
  toDF("downloadId", "date", "downloadCount")

def baseDate = udf( (d: java.sql.Date, n: Long) =>
  new java.sql.Date(new java.util.Date(d.getTime).getTime - n * 24 * 60 * 60 * 1000)
)

import org.apache.spark.sql.expressions.Window

val dfStreak = df.withColumn("rowNum", row_number.over(
    Window.partitionBy($"downloadId").orderBy($"date")
  )
).withColumn(
  "baseDate", baseDate($"date", $"rowNum")
).select(
  $"downloadId", $"date", $"downloadCount", row_number.over(
    Window.partitionBy($"downloadId", $"baseDate").orderBy($"date")
  ).as("streak")
).orderBy($"downloadId", $"date")

dfStreak.show
+----------+----------+-------------+------+
|downloadId|      date|downloadCount|streak|
+----------+----------+-------------+------+
|         1|2017-12-13|            2|     1|
|         1|2017-12-16|            1|     1|
|         1|2017-12-17|            1|     2|
|         1|2017-12-18|            2|     3|
|         1|2017-12-20|            1|     1|
|         1|2017-12-21|            3|     2|
|         2|2017-12-15|            2|     1|
|         2|2017-12-16|            1|     2|
|         2|2017-12-19|            1|     1|
|         2|2017-12-20|            1|     2|
|         2|2017-12-21|            2|     3|
|         2|2017-12-23|            1|     1|
+----------+----------+-------------+------+

【讨论】:

以上是关于如何对 RDD 中的项目进行排名以建立连胜?的主要内容,如果未能解决你的问题,请参考以下文章

如何按 RDD 中的选定字段数进行分组,以查找基于这些字段的重复项

如何添加计数以对 SQL Hive 中的空值进行排名?

如何根据sklearn中的预测概率对实例进行排名

如何根据列中的一组行对数据框进行排名?

如何根据行日差和分区对 SQL 中的列进行排名?

Spark - 如何使用有状态映射器对已排序的 RDD 进行平面映射?