在 Spark Scala 中的列上运行累积/迭代 Costum 方法

Posted

技术标签:

【中文标题】在 Spark Scala 中的列上运行累积/迭代 Costum 方法【英文标题】:Run a Cumulative/Iterative Costum Method on a Column in Spark Scala 【发布时间】:2017-08-15 14:54:14 【问题描述】:

您好,我是 Spark/Scala 的新手,我一直在尝试 - AKA 失败,基于特定的递归公式在 spark 数据框中创建列:

这里是伪代码。

someDf.col2[0] = 0

for i > 0
someDf.col2[i] = x * someDf.col1[i-1] + (1-x) * someDf.col2[i-1]

要深入了解更多细节,这是我的出发点: 该数据框是dates 和个人id 级别聚合的结果。

所有进一步的计算都必须针对特定的id 进行,并且必须考虑到前一周发生的情况。

为了说明这一点,我将值简化为 0 和 1,并删除了乘数 x1-x,我还将 col2 初始化为零。

var someDf = Seq(("2016-01-10 00:00:00.0","385608",0,0), 
         ("2016-01-17 00:00:00.0","385608",0,0),
         ("2016-01-24 00:00:00.0","385608",1,0),
         ("2016-01-31 00:00:00.0","385608",1,0),
         ("2016-02-07 00:00:00.0","385608",1,0),
         ("2016-02-14 00:00:00.0","385608",1,0),
         ("2016-01-17 00:00:00.0","105010",0,0),
         ("2016-01-24 00:00:00.0","105010",1,0),
         ("2016-01-31 00:00:00.0","105010",0,0),
         ("2016-02-07 00:00:00.0","105010",1,0)
        ).toDF("dates", "id", "col1","col2" )

someDf.show()
+--------------------+------+----+----+
|               dates|    id|col1|col2|
+--------------------+------+----+----+
|2016-01-10 00:00:...|385608|   0|   0|
|2016-01-17 00:00:...|385608|   0|   0|
|2016-01-24 00:00:...|385608|   1|   0|
|2016-01-31 00:00:...|385608|   1|   0|
|2016-02-07 00:00:...|385608|   1|   0|
|2016-02-14 00:00:...|385608|   1|   0|
+--------------------+------+----+----+
|2016-01-17 00:00:...|105010|   0|   0|
|2016-01-24 00:00:...|105010|   1|   0|
|2016-01-31 00:00:...|105010|   0|   0|
|2016-02-07 00:00:...|105010|   1|   0|
+--------------------+------+----+----+

到目前为止我尝试过的与想要的

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val date_id_window = Window.partitionBy("id").orderBy(asc("dates")) 

someDf.withColumn("col2", lag($"col1",1 ).over(date_id_window) + 
lag($"col2",1 ).over(date_id_window) ).show() 
+--------------------+------+----+----+ / +--------------------+
|               dates|    id|col1|col2| / | what_col2_should_be|
+--------------------+------+----+----+ / +--------------------+
|2016-01-17 00:00:...|105010|   0|null| / |                   0| 
|2016-01-24 00:00:...|105010|   1|   0| / |                   0|
|2016-01-31 00:00:...|105010|   0|   1| / |                   1|
|2016-02-07 00:00:...|105010|   1|   0| / |                   1|
+-------------------------------------+ / +--------------------+
|2016-01-10 00:00:...|385608|   0|null| / |                   0|
|2016-01-17 00:00:...|385608|   0|   0| / |                   0|
|2016-01-24 00:00:...|385608|   1|   0| / |                   0|
|2016-01-31 00:00:...|385608|   1|   1| / |                   1|
|2016-02-07 00:00:...|385608|   1|   1| / |                   2|
|2016-02-14 00:00:...|385608|   1|   1| / |                   3|
+--------------------+------+----+----+ / +--------------------+

有没有办法用 Spark 数据帧来做到这一点,我已经看到了多个累积类型计算,但从不包括同一列,我认为问题是不考虑第 i-1 行的新计算值,而是使用旧的 i-1,它始终为 0。

任何帮助将不胜感激。

【问题讨论】:

【参考方案1】:

Dataset 应该可以正常工作:

val x = 0.1

case class Record(dates: String, id: String, col1: Int)

someDf.drop("col2").as[Record].groupByKey(_.id).flatMapGroups((_,  records) => 
  val sorted = records.toSeq.sortBy(_.dates)
  sorted.scanLeft((null: Record, 0.0))
    case ((_, col2), record) => (record, x * record.col1 + (1 - x) * col2)
  .tail
).select($"_1.*", $"_2".alias("col2"))

【讨论】:

【参考方案2】:

您可以将rowsBetween api 与您正在使用的Window 函数一起使用,并且您应该获得所需的输出

val date_id_window = Window.partitionBy("id").orderBy(asc("dates"))
someDf.withColumn("col2", sum(lag($"col1", 1).over(date_id_window)).over(date_id_window.rowsBetween(Long.MinValue, 0)))
  .withColumn("col2", when($"col2".isNull, lit(0)).otherwise($"col2"))
  .show()

给定输入dataframe

+--------------------+------+----+----+
|               dates|    id|col1|col2|
+--------------------+------+----+----+
|2016-01-10 00:00:...|385608|   0|   0|
|2016-01-17 00:00:...|385608|   0|   0|
|2016-01-24 00:00:...|385608|   1|   0|
|2016-01-31 00:00:...|385608|   1|   0|
|2016-02-07 00:00:...|385608|   1|   0|
|2016-02-14 00:00:...|385608|   1|   0|
|2016-01-17 00:00:...|105010|   0|   0|
|2016-01-24 00:00:...|105010|   1|   0|
|2016-01-31 00:00:...|105010|   0|   0|
|2016-02-07 00:00:...|105010|   1|   0|
+--------------------+------+----+----+

应用上述逻辑后,您应该有输出数据框

+--------------------+------+----+----+
|               dates|    id|col1|col2|
+--------------------+------+----+----+
|2016-01-17 00:00:...|105010|   0|   0|
|2016-01-24 00:00:...|105010|   1|   0|
|2016-01-31 00:00:...|105010|   0|   1|
|2016-02-07 00:00:...|105010|   1|   1|
|2016-01-10 00:00:...|385608|   0|   0|
|2016-01-17 00:00:...|385608|   0|   0|
|2016-01-24 00:00:...|385608|   1|   0|
|2016-01-31 00:00:...|385608|   1|   1|
|2016-02-07 00:00:...|385608|   1|   2|
|2016-02-14 00:00:...|385608|   1|   3|
+--------------------+------+----+----+

希望回答对你有帮助

【讨论】:

【参考方案3】:

您应该对数据框应用转换,而不是将其视为var。获得所需内容的一种方法是使用 Window 的 rowsBetween 对每个窗口分区内的行通过前一行(即行 -1)累计求和 col1 的值:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val window = Window.partitionBy("id").orderBy("dates").rowsBetween(Long.MinValue, -1)

val newDF = someDf.
  withColumn(
    "col2", sum($"col1").over(window)
  ).withColumn(
    "col2", when($"col2".isNull, 0).otherwise($"col2")
  ).orderBy("id", "dates")

newDF.show
+--------------------+------+----+----+
|               dates|    id|col1|col2|
+--------------------+------+----+----+
|2016-01-17 00:00:...|105010|   0|   0|
|2016-01-24 00:00:...|105010|   1|   0|
|2016-01-31 00:00:...|105010|   0|   1|
|2016-02-07 00:00:...|105010|   1|   1|
|2016-01-10 00:00:...|385608|   0|   0|
|2016-01-17 00:00:...|385608|   0|   0|
|2016-01-24 00:00:...|385608|   1|   0|
|2016-01-31 00:00:...|385608|   1|   1|
|2016-02-07 00:00:...|385608|   1|   2|
|2016-02-14 00:00:...|385608|   1|   3|
+--------------------+------+----+----+

【讨论】:

以上是关于在 Spark Scala 中的列上运行累积/迭代 Costum 方法的主要内容,如果未能解决你的问题,请参考以下文章

如何在倾斜列上重新分区 Spark scala 中的数据框?

最近 n_days 使用 groupby 在特定列上的累积总和

Spark 3.0 排序并应用于组 Scala/Java

如何规范化 spark (scala) 中的列中的全角字符

Scala(Spark)连接数据框中的列[重复]

遍历火花数据框中的列并计算最小值最大值