Spark SQL 的累积积 UDF

Posted

技术标签:

【中文标题】Spark SQL 的累积积 UDF【英文标题】:Cumulative product UDF for Spark SQL 【发布时间】:2020-04-09 16:33:28 【问题描述】:

我在其他帖子中看到过为数据框完成此操作:https://***.com/a/52992212/4080521

但我想弄清楚如何为累积产品编写 udf。

假设我有一个非常基本的表格

Input data:
+----+
| val|
+----+
| 1  |
| 2  |
| 3  |
+----+

如果我想取这个总和,我可以简单地做类似的事情

sparkSession.createOrReplaceTempView("table")
spark.sql("""Select SUM(table.val) from table""").show(100, false)

这很有效,因为 SUM 是一个预定义的函数。

我如何为乘法定义类似的东西(或者我自己如何在UDF 中实现求和)?

尝试以下方法

sparkSession.createOrReplaceTempView("_Period0")

val prod = udf((vals:Seq[Decimal]) => vals.reduce(_ * _))
spark.udf.register("prod",prod)

spark.sql("""Select prod(table.vals) from table""").show(100, false)

我收到以下错误:

Message: cannot resolve 'UDF(vals)' due to data type mismatch: argument 1 requires array<decimal(38,18)> type, however, 'table.vals' is of decimal(28,14)

显然每个特定的单元格不是一个数组,但似乎 udf 需要接受一个数组来执行聚合。 spark sql 甚至可以吗?

【问题讨论】:

How to define and use a User-Defined Aggregate Function in Spark SQL?的可能重复 【参考方案1】:

可以通过UserDefinedAggregateFunction实现 您需要定义几个函数来处理输入和缓冲区值。

使用 double 作为类型的产品函数的快速示例:

  import org.apache.spark.sql.expressions.MutableAggregationBuffer
  import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
  import org.apache.spark.sql.Row
  import org.apache.spark.sql.types._


    class myUDAF extends UserDefinedAggregateFunction 

      // inputSchema for the function
      override def inputSchema: StructType = 
        new StructType().add("val", DoubleType, nullable = true)
      

     //Schema for the inner UDAF buffer, in the product case, you just need an accumulator
     override def bufferSchema: StructType = StructType(StructField("accumulated", DoubleType) :: Nil)

    //OutputDataType
    override def dataType: DataType = DoubleType

    override def deterministic: Boolean = true

    //Initicla buffer value 1 for product
    override def initialize(buffer: MutableAggregationBuffer) = buffer(0) = 1.0

    //How to update the buffer, for product you just need to perform a product between the two elements (buffer & input)
    override def update(buffer: MutableAggregationBuffer, input: Row) = 
        buffer(0) = buffer.getAs[Double](0) * input.getAs[Double](0)
      

      //Merge results with the previous buffered value (product as well here)
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = 
        buffer1(0) = buffer1.getAs[Double](0) * buffer2.getAs[Double](0)
      

      //Function on how to return the value
      override def evaluate(buffer: Row) = buffer.getAs[Double](0)

    

然后您可以注册函数,就像使用任何其他 UDF 一样:

spark.udf.register("prod", new myUDAF)

结果

scala> spark.sql("Select prod(val) from table").show
+-----------+
|myudaf(val)|
+-----------+
|        6.0|
+-----------+

您可以找到更多文档here

【讨论】:

以上是关于Spark SQL 的累积积 UDF的主要内容,如果未能解决你的问题,请参考以下文章

Spark SQL 的累积非重复计数

spark的udf和udaf的注册

算法总结之 不包含本位置的累乘数组

《程序员代码面试指南》第八章 数组和矩阵问题 不包含本位置值的累乘数组

[算法]不包含本位置值的累乘数组

java普通类怎么调用service曾的累