如何将数组传递给 Spark (UDAF) 中的用户定义聚合函数

Posted

技术标签:

【中文标题】如何将数组传递给 Spark (UDAF) 中的用户定义聚合函数【英文标题】:How to pass an array to an User Defined Aggregation Function in Spark (UDAF) 【发布时间】:2019-05-31 09:25:42 【问题描述】:

我想在 UDAF 中传递一个数组作为输入模式。

我给出的例子很简单,它只是将 2 个向量相加。实际上我的用例更复杂,我需要使用 UDAF。

import sc.implicits._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions._

val df = Seq(
  (1, Array(10.2, 12.3, 11.2)),
  (1, Array(11.2, 12.6, 10.8)),
  (2, Array(12.1, 11.2, 10.1)),
  (2, Array(10.1, 16.0, 9.3)) 
  ).toDF("siteId", "bidRevenue")


class BidAggregatorBySiteId() extends UserDefinedAggregateFunction 

  def inputSchema: StructType = StructType(Array(StructField("bidRevenue", ArrayType(DoubleType))))

  def bufferSchema = StructType(Array(StructField("sumArray", ArrayType(DoubleType))))

  def dataType: DataType = ArrayType(DoubleType)

  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = 
      buffer.update(0, Array(0.0, 0.0, 0.0))
      

  def update(buffer: MutableAggregationBuffer, input: Row) = 
      val seqBuffer = buffer(0).asInstanceOf[IndexedSeq[Double]]
      val seqInput = input(0).asInstanceOf[IndexedSeq[Double]]
      buffer(0) = seqBuffer.zip(seqInput).map case (x, y) => x + y 
  

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = 
     val seqBuffer1 = buffer1(0).asInstanceOf[IndexedSeq[Double]]
     val seqBuffer2 = buffer2(0).asInstanceOf[IndexedSeq[Double]]
     buffer1(0) = seqBuffer1.zip(seqBuffer2).map case (x, y) => x + y 
  

  def evaluate(buffer: Row) =  
    buffer
  

val fun = new BidAggregatorBySiteId()

df.select($"siteId", $"bidRevenue" cast(ArrayType(DoubleType)))
.groupBy("siteId").agg(fun($"bidRevenue"))
.show

“显示”操作之前的所有转换都可以正常工作。但是节目引发了错误:

scala.MatchError: [WrappedArray(21.4, 24.9, 22.0)](属于 org.apache.spark.sql.execution.aggregate.InputAggregationBuffer 类) 在 org.apache.spark.sql.catalyst.CatalystTypeConverters$ArrayConverter.toCatalystImpl(CatalystTypeConverters.scala:160)

我的数据框的结构是:

root
 |-- siteId: integer (nullable = false)
 |-- bidRevenue: array (nullable = true)
 |    |-- element: double (containsNull = true)

df.dtypes = Array[(String, String)] = Array(("siteId", "IntegerType"), ("bidRevenue", "ArrayType(DoubleType,true)"))

为您提供宝贵的帮助。

【问题讨论】:

【参考方案1】:
def evaluate(buffer: Row): Any

一旦一个组被完全处理以获得最终结果,就会调用上述方法。 当您仅初始化和更新缓冲区的第 0 个索引时

i.e. buffer(0)  

因此您需要在最后返回第 0 个索引值,因为您的聚合结果存储在 0 索引处。

  def evaluate(buffer: Row) = 
    buffer.get(0)
  

上面对evaluate()方法的修改会导致:

// +------+---------------------------------+
// |siteId|bidaggregatorbysiteid(bidRevenue)|
// +------+---------------------------------+
// |     1|               [21.4, 24.9, 22.0]|
// |     2|               [22.2, 27.2, 19.4]|
// +------+---------------------------------+

【讨论】:

非常感谢。我没有看到明显的东西,而你看到了。 @FrançoisBagaïni 没问题,如果您能对答案投票或将其标记为解决方案,那就太好了。 对不起,我不能为答案投票,因为我没有足够的声誉

以上是关于如何将数组传递给 Spark (UDAF) 中的用户定义聚合函数的主要内容,如果未能解决你的问题,请参考以下文章

spark-streaming scala:如何将字符串数组传递给过滤器?

Spark - 将通用数组传递给 GenericRowWithSchema

何时合并发生在Spark中的用户定义聚合函数UDAF中

使用 ArrayType 作为 bufferSchema 的 Spark UDAF 性能问题

如何在 Spark UDAF 中实现 fastutils 映射?

Scala 中的 Spark SQL(v2.0) UDAF 返回空字符串