Spark之UDAF
Posted yszd
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Spark之UDAF相关的知识,希望对你有一定的参考价值。
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Created by zhen on 2018/11/26.
 *
 * User-defined aggregate function (UDAF) that computes the arithmetic mean of
 * a `LongType` column. Null inputs are skipped; if no non-null value was
 * aggregated the result is SQL NULL rather than NaN.
 *
 * The aggregation buffer holds two longs: the running sum (index 0) and the
 * number of non-null values seen so far (index 1).
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated as of Spark 3.0
 * in favour of `Aggregator`; keep this form only while targeting Spark 2.x.
 */
object AverageUserDefinedAggregateFunction extends UserDefinedAggregateFunction {

  /** Schema of the function's input arguments: a single long column. */
  override def inputSchema: StructType =
    StructType(StructField("input", LongType) :: Nil)

  /** Schema of the intermediate aggregation buffer: running sum and count. */
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  /** Data type of the final result. */
  override def dataType: DataType = DoubleType

  /** The same input always produces the same output. */
  override def deterministic: Boolean = true

  /** Resets the buffer to the empty aggregate: sum = 0, count = 0. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  /**
   * Folds one input row into the buffer.
   *
   * @param buffer aggregation buffer to update in place
   * @param input  row holding the single input value at index 0
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Null values contribute to neither the sum nor the count.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0) // running sum
      buffer(1) = buffer.getLong(1) + 1 // running count
    }
  }

  /**
   * Merges a second partial aggregate into the first one.
   *
   * @param bufferLeft  buffer that receives the merged result
   * @param bufferRight partial aggregate to add into `bufferLeft`
   */
  override def merge(bufferLeft: MutableAggregationBuffer, bufferRight: Row): Unit = {
    bufferLeft(0) = bufferLeft.getLong(0) + bufferRight.getLong(0)
    bufferLeft(1) = bufferLeft.getLong(1) + bufferRight.getLong(1)
  }

  /**
   * Produces the final average from the buffer.
   *
   * @param buffer fully merged aggregation buffer
   * @return the mean as a `Double`, or `null` (SQL NULL) when no non-null
   *         value was aggregated
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    // Dividing by a zero count would yield NaN; null is how this API
    // expresses SQL NULL.
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}

/** Demo driver: registers the UDAF and runs it over a JSON file of users. */
object SparkUdaf {

  /** Input file used when no path is supplied on the command line. */
  private val DefaultInputPath = "E:/BDS/newsparkml/src/udaf.json"

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the default JSON input path
   */
  def main(args: Array[String]): Unit = {
    val inputPath = args.headOption.getOrElse(DefaultInputPath)

    val spark = SparkSession
      .builder()
      .appName("udaf")
      .master("local[2]")
      .getOrCreate()

    try {
      spark.read.json(inputPath).createOrReplaceTempView("user")
      spark.udf.register("average", AverageUserDefinedAggregateFunction)
      spark.sql("select count(*) count,average(age) avg_age from user").show()
    } finally {
      // Always release the session, even if reading or the query fails.
      spark.stop()
    }
  }
}
结果:
以上是关于Spark之UDAF的主要内容,如果未能解决你的问题,请参考以下文章:
Scala 中的 Spark SQL(v2.0) UDAF 返回空字符串