SparkSQL自定义无类型聚合函数

Posted sysocjs

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了SparkSQL自定义无类型聚合函数相关的知识,希望对你有一定的参考价值。

准备数据:

Michael,3000
Andy,4500
Justin,3500
Betral,4000

一、定义自定义无类型聚合函数

        想要自定义无类型聚合函数,那必须得继承org.spark.sql.expressions.UserDefinedAggregateFunction,然后重写父类得抽象变量和成员方法。

package com.cjs
 
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.MutableAggregationBuffer, UserDefinedAggregateFunction
import org.apache.spark.sql.types._
 
object UDFMyAverage extends UserDefinedAggregateFunction
    //定义输入参数的数据类型
    override def inputSchema: StructType = StructType(StructField("inputColumn", LongType)::Nil)
    //定义缓冲器的数据结构类型,缓冲器用于计算,这里定义了两个数据变量:sum和count
    override def bufferSchema: StructType = StructType(StructField("sum",LongType)::StructField("count",LongType)::Nil)
 
    //聚合函数返回的数据类型
    override def dataType: DataType = DoubleType
 
    override def deterministic: Boolean = true
    //初始化缓冲器
    override def initialize(buffer: MutableAggregationBuffer): Unit = 
        //buffer本质上也是一个Row对象,所以也可以使用下标的方式获取它的元素
        buffer(0) = 0L  //这里第一个元素是上面定义的sum
        buffer(1) = 0L  //这里第二个元素是上面定义的sount
    
 
    //update方法用于将输入数据跟缓冲器数据进行计算,这里是一个累加的作用
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = 
        buffer(0) = buffer.getLong(0) + input.getLong(0)
        buffer(1) = buffer.getLong(1) + 1
    
 
    //buffer1是主缓冲器,储存的是目前各个节点的部分计算结果;buffer2是分布式中执行任务的各个节点的“主”缓冲器;
    // merge方法作用是将各个节点的计算结果做一个聚合,其实可以理解为分布式的update的方法,buffer2相当于input:Row
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = 
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    
 
    //计算最终结果
    override def evaluate(buffer: Row): Any = 
        buffer.getLong(0).toDouble/buffer.getLong(1)
    

二、使用自定义无类型聚合函数

package com.cjs
 
import org.apache.log4j.Level, Logger
import org.apache.spark.SparkConf
import org.apache.spark.sql.Row, SparkSession
import org.apache.spark.sql.types.StringType, StructField, StructType
 
object TestMyAverage 
    def main(args: Array[String]): Unit = 
        Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
 
        val conf = new SparkConf()
            .set("spark.some.config.option","some-value")
            .set("spark.sql.warehouse.dir","file:///e:/tmp/spark-warehouse")
 
        val ss = SparkSession
            .builder()
            .config(conf)
            .appName("test-myAverage")
            .master("local[2]")
            .getOrCreate()
 
        import ss.implicits._
        val sc = ss.sparkContext
 
        val schemaString = "name,salary"
        val fileds = schemaString.split(",").map(filedName => StructField(filedName,StringType, nullable = true))
        val schemaStruct = StructType(fileds)
 
        val path = "E:\\IntelliJ Idea\\sparkSql_practice\\src\\main\\scala\\com\\cjs\\employee.txt"
        val empRDD = sc.textFile(path).map(_.split(",")).map(row=>Row(row(0),row(1)))
 
        val empDF = ss.createDataFrame(empRDD,schemaStruct)
        empDF.createOrReplaceTempView("emp")
//        ss.sql("select name, salary from emp limit 5").show()
        //想要在spark sql里使用无类型自定义聚合函数,那么就要先注册给自定义函数
        ss.udf.register("myAverage",UDFMyAverage)
 
//        empDF.show()
        ss.sql("select myAverage(salary) as average_salary from emp").show()
    
 

输出结果:

技术图片

以上是关于SparkSQL自定义无类型聚合函数的主要内容,如果未能解决你的问题,请参考以下文章

Spark学习之路 (十九)SparkSQL的自定义函数UDF

Spark学习之路 (十九)SparkSQL的自定义函数UDF[转]

如何在 Spark SQL 中定义和使用用户定义的聚合函数?

Spark篇---SparkSQL中自定义UDF和UDAF,开窗函数的应用

Spark SQL自定义函数

入门大数据---SparkSQL常用聚合函数