如何对 RDD 中的项目进行排名以建立连胜?
Posted
技术标签:
【中文标题】如何对 RDD 中的项目进行排名以建立连胜?【英文标题】:How can I rank items in a RDD to build a streak? 【发布时间】:2018-01-02 20:23:35 【问题描述】:我有一个包含如下数据的 RDD:(downloadId: String, date: LocalDate, downloadCount: Int)
。日期和下载 ID 是唯一的,下载计数是针对日期的。
我一直在尝试完成的是获取下载 ID 在所有下载 ID 中排名前 100 的连续天数(从当前日期倒数)。因此,如果给定的下载量在今天、昨天和前一天都在前 100 名中,那么它的连续性将是 3。
在 SQL 中,我想这可以使用窗口函数来解决。我见过类似的问题。 How to add a running count to rows in a 'streak' of consecutive days
(我对 Spark 比较陌生,但不确定如何映射减少 RDD 甚至开始解决这样的问题。)
更多信息,日期是最近 30 天,每天大约有 400 万个唯一下载 ID。
【问题讨论】:
【参考方案1】:我建议您使用 DataFrame,因为它们比 RDD 更容易使用。 Leo 的答案较短,但我找不到它在哪里过滤前 100 名下载,所以我决定也发布我的答案。它不依赖于窗口函数,但它受限于您想要连续运行的过去天数。既然你说你只使用最近 30 天的数据,那应该没问题。
作为第一步,我编写了一些代码来生成类似于您所描述的 DF。您不需要运行第一个块(如果这样做,请减少行数,除非您有一个集群可以尝试它,它会占用大量内存)。您可以看到如何将 RDD (theData
) 转换为 DF (baseData
)。你应该像我一样为它定义一个模式。
import java.time.LocalDate
import scala.util.Random
val maxId = 10000
val numRows = 15000000
val lastDate = LocalDate.of(2017, 12, 31)
// Generates the data. As a convenience for working with Dataframes, I converted the dates to epoch days.
val theData = sc.parallelize(1.to(numRows).map
_ =>
val id = Random.nextInt(maxId)
val nDownloads = Random.nextInt((id / 1000 + 1))
Row(id, lastDate.minusDays(Random.nextInt(30)).toEpochDay, nDownloads)
)
//Working with Dataframes is much simples, so I'll generate a DF named baseData from the RDD
val schema = StructType(
StructField("downloadId", IntegerType, false) ::
StructField("date", LongType, false) ::
StructField("downloadCount", IntegerType, false) :: Nil)
val baseData = sparkSession.sqlContext.createDataFrame(theData, schema)
.groupBy($"downloadId", $"date")
.agg(sum($"downloadCount").as("downloadCount"))
.cache()
现在您在名为 baseData
的 DF 中获得了所需的数据。下一步是将其限制在每天的前 100 位 - 在进行任何额外的繁重转换之前,您应该丢弃不使用的数据。
import org.apache.spark.sql.types._
import org.apache.spark.sql.DataFrame, Row
def filterOnlyTopN(data: DataFrame, n: Int = 100): DataFrame =
// For each day in the data, let's find the cutoff # of downloads to make it into the top N
val getTopNCutoff = udf((downloads: Seq[Long]) =>
val reverseSortedDownloads = downloads.sortBy- _
if (reverseSortedDownloads.length >= n)
reverseSortedDownloads.drop(n - 1).head
else
reverseSortedDownloads.last
)
val topNLimitsByDate = data.groupBy($"date").agg(collect_set($"downloadCount").as("downloads"))
.select($"date", getTopNCutoff($"downloads").as("cutoff"))
// And then, let's throw away the records below the top 100
data.join(topNLimitsByDate, Seq("date"))
.filter($"downloadCount" >= $"cutoff")
.drop("cutoff", "downloadCount")
val relevantData = filterOnlyTopN(baseData)
现在您已经有了只包含您需要的数据的relevantData
DF,您可以计算它们的连续性。我将没有条纹的 id 保留为条纹 0,您可以使用 streaks.filter($"streak" > lit(0))
过滤掉它们。
def getStreak(df: DataFrame, fromDate: Long): DataFrame =
val calcStreak = udf((dateList: Seq[Long]) =>
if (!dateList.contains(fromDate))
0
else
val relevantDates = dateList.sortBy- _ // Order the dates descending
.dropWhile(_ != fromDate) // And drop everything until we find the starting day we are interested in
if (relevantDates.length == 1) // If there's only one day left, it's a one day streak
1
else // Otherwise, let's count the streak length (this works if no dates are left, too - but not with only 1 day)
relevantDates.sliding(2) // Take days by pairs
.takeWhiletwoDays => twoDays(1) == twoDays(0) - 1 // While the pair is of consecutive days
.length+1 // And the streak will be the number of consecutive pairs + 1 (the initial day of the streak)
)
df.groupBy($"downloadId").agg(collect_list($"date").as("dates")).select($"downloadId", calcStreak($"dates").as("streak"))
val streaks = getStreak(relevantData, lastDate.toEpochDay)
streaks.show()
+------------+--------+
| downloadId | streak |
+------------+--------+
| 8086 | 0 |
| 9852 | 0 |
| 7253 | 0 |
| 9376 | 0 |
| 7833 | 0 |
| 9465 | 1 |
| 7880 | 0 |
| 9900 | 1 |
| 7993 | 0 |
| 9427 | 1 |
| 8389 | 1 |
| 8638 | 1 |
| 8592 | 1 |
| 6397 | 0 |
| 7754 | 1 |
| 7982 | 0 |
| 7554 | 0 |
| 6357 | 1 |
| 7340 | 0 |
| 6336 | 0 |
+------------+--------+
你有 streaks
DF 和你需要的数据。
【讨论】:
【参考方案2】:使用列出的 PostgreSQL 链接中的类似方法,您也可以在 Spark 中应用 Window
函数。 Spark 的 DataFrame API 没有 java.time.LocalDate
的编码器,因此您需要将其转换为 java.sql.Date
。
步骤如下:首先,将 RDD 转换为支持日期格式的 DataFrame;接下来,创建一个UDF
来计算baseDate
,它需要一个日期和一个按id 时间顺序的行号(使用Window 函数生成)作为参数。另一个 Window 函数用于计算 per-id-baseDate 行数,这是想要的条纹值:
import java.time.LocalDate
val rdd = sc.parallelize(Seq(
(1, LocalDate.parse("2017-12-13"), 2),
(1, LocalDate.parse("2017-12-16"), 1),
(1, LocalDate.parse("2017-12-17"), 1),
(1, LocalDate.parse("2017-12-18"), 2),
(1, LocalDate.parse("2017-12-20"), 1),
(1, LocalDate.parse("2017-12-21"), 3),
(2, LocalDate.parse("2017-12-15"), 2),
(2, LocalDate.parse("2017-12-16"), 1),
(2, LocalDate.parse("2017-12-19"), 1),
(2, LocalDate.parse("2017-12-20"), 1),
(2, LocalDate.parse("2017-12-21"), 2),
(2, LocalDate.parse("2017-12-23"), 1)
))
val df = rdd.map case (id, date, count) => (id, java.sql.Date.valueOf(date), count) .
toDF("downloadId", "date", "downloadCount")
def baseDate = udf( (d: java.sql.Date, n: Long) =>
new java.sql.Date(new java.util.Date(d.getTime).getTime - n * 24 * 60 * 60 * 1000)
)
import org.apache.spark.sql.expressions.Window
val dfStreak = df.withColumn("rowNum", row_number.over(
Window.partitionBy($"downloadId").orderBy($"date")
)
).withColumn(
"baseDate", baseDate($"date", $"rowNum")
).select(
$"downloadId", $"date", $"downloadCount", row_number.over(
Window.partitionBy($"downloadId", $"baseDate").orderBy($"date")
).as("streak")
).orderBy($"downloadId", $"date")
dfStreak.show
+----------+----------+-------------+------+
|downloadId| date|downloadCount|streak|
+----------+----------+-------------+------+
| 1|2017-12-13| 2| 1|
| 1|2017-12-16| 1| 1|
| 1|2017-12-17| 1| 2|
| 1|2017-12-18| 2| 3|
| 1|2017-12-20| 1| 1|
| 1|2017-12-21| 3| 2|
| 2|2017-12-15| 2| 1|
| 2|2017-12-16| 1| 2|
| 2|2017-12-19| 1| 1|
| 2|2017-12-20| 1| 2|
| 2|2017-12-21| 2| 3|
| 2|2017-12-23| 1| 1|
+----------+----------+-------------+------+
【讨论】:
以上是关于如何对 RDD 中的项目进行排名以建立连胜?的主要内容,如果未能解决你的问题,请参考以下文章