Cumulative product UDF for Spark SQL
Posted: 2020-04-09 16:33:28

I have seen this done for dataframes in another post: https://***.com/a/52992212/4080521

But I am trying to figure out how to write a UDF for a cumulative product.
Suppose I have a very basic table:
Input data:
+----+
| val|
+----+
|   1|
|   2|
|   3|
+----+
If I want to take the sum of this, I can simply do something like:
df.createOrReplaceTempView("table")
spark.sql("""Select SUM(table.val) from table""").show(100, false)
This works fine because SUM is a predefined function.
How would I define something similar for multiplication (or how could I implement sum myself as a UDF)?
Trying the following approach:
df.createOrReplaceTempView("_Period0")
val prod = udf((vals:Seq[Decimal]) => vals.reduce(_ * _))
spark.udf.register("prod",prod)
spark.sql("""Select prod(table.vals) from table""").show(100, false)
I get the following error:
Message: cannot resolve 'UDF(vals)' due to data type mismatch: argument 1 requires array<decimal(38,18)> type, however, 'table.vals' is of decimal(28,14)
Obviously each individual cell is not an array, but it seems the UDF needs to accept an array in order to perform the aggregation. Is this even possible in Spark SQL?
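(As an aside: one common workaround avoids a UDF altogether by rewriting the product through built-in aggregates, since for strictly positive values exp(sum(log(x))) equals the product of the x values. In Spark SQL that would read roughly `SELECT EXP(SUM(LOG(val))) FROM table`; the positivity restriction is the catch. A minimal plain-Scala check of the identity:)

```scala
// Identity behind the no-UDF workaround: for strictly positive values,
// product(xs) == exp(sum(log(xs))). Verified here in plain Scala.
val vals = Seq(1.0, 2.0, 3.0)
val viaLogSum = math.exp(vals.map(math.log).sum)
val direct = vals.product
assert(math.abs(viaLogSum - direct) < 1e-9) // both evaluate to 6.0
```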
Comments:
Possible duplicate of How to define and use a User-Defined Aggregate Function in Spark SQL?

Answer 1:

This can be achieved with a UserDefinedAggregateFunction.
You need to define a few functions that handle the input and buffer values.
A quick example of a product function using Double as the type:
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class myUDAF extends UserDefinedAggregateFunction {
  // Input schema for the function
  override def inputSchema: StructType =
    new StructType().add("val", DoubleType, nullable = true)
  // Schema for the internal UDAF buffer; for a product you only need an accumulator
  override def bufferSchema: StructType = StructType(StructField("accumulated", DoubleType) :: Nil)
  // Output data type
  override def dataType: DataType = DoubleType
  override def deterministic: Boolean = true
  // Initial buffer value: 1 for a product
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 1.0
  // How to update the buffer: for a product, just multiply the two elements (buffer & input)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer(0) = buffer.getAs[Double](0) * input.getAs[Double](0)
  // Merge results with a previously buffered value (a product as well here)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getAs[Double](0) * buffer2.getAs[Double](0)
  // How to return the final value
  override def evaluate(buffer: Row): Double = buffer.getAs[Double](0)
}
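To see why both update and merge multiply, here is a plain-Scala sketch of the lifecycle Spark drives (illustrative names, no Spark required): update folds the rows within each partition into a buffer, merge combines the per-partition buffers, and evaluate extracts the final value:

```scala
// Plain-Scala mimic of the UDAF lifecycle: Spark calls update() per row
// within a partition, then merge() across the per-partition buffers.
object ProductLifecycle {
  val initial: Double = 1.0 // neutral element for a product
  def update(buffer: Double, input: Double): Double = buffer * input
  def merge(b1: Double, b2: Double): Double = b1 * b2
  def evaluate(buffer: Double): Double = buffer

  def aggregate(partitions: Seq[Seq[Double]]): Double = {
    val partials = partitions.map(_.foldLeft(initial)(update)) // one buffer per partition
    evaluate(partials.foldLeft(initial)(merge))                // combine partial products
  }
}

// e.g. rows {1.0, 2.0} on one partition and {3.0} on another:
// ProductLifecycle.aggregate(Seq(Seq(1.0, 2.0), Seq(3.0))) == 6.0
```

Because multiplication is associative and commutative with 1.0 as its neutral element, the result is the same however Spark splits the rows across partitions.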
Then you can register the function just as you would any other UDF:
spark.udf.register("prod", new myUDAF)
Result:
scala> spark.sql("Select prod(val) from table").show
+-----------+
|myudaf(val)|
+-----------+
| 6.0|
+-----------+
You can find more documentation here.