spark的udf和udaf的注册
Posted jeasonchen001
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了spark的udf和udaf的注册相关的知识,希望对你有一定的参考价值。
spark的udf和udaf的注册
一、udf
// Register a scalar UDF called "addName" that prefixes its string argument with "name: ".
spark.udf.register("addName", (name: String) => "name: " + name)
二、udaf
- 弱类型的自定义聚合函数:输入和缓存都通过 Row 按下标访问,类型错误在编译期无法发现,因此是不安全的
package com.huawei.appgallery.udf
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
* author:Chen
* 弱类型自定义聚合函数
* date:2020/2/12 14:29
*/
/**
 * Untyped (Row-based) user-defined aggregate function that computes an average salary.
 *
 * It is invoked with two arguments, `(name, salary)`; only the salary (argument
 * index 1) contributes to the result. Rows whose salary is NULL are ignored, so
 * they affect neither the sum nor the count.
 */
object MyAverage extends UserDefinedAggregateFunction {

  /** Schema of the arguments passed to the function: (name, salary). */
  override def inputSchema: StructType =
    StructType(
      StructField("name", StringType, nullable = true) ::
        StructField("salary", LongType, nullable = false) :: Nil)

  /** Schema of the intermediate aggregation buffer: running sum and row count. */
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  /** Type of the final aggregated value. */
  override def dataType: DataType = DoubleType

  /** The function always returns the same result for the same input. */
  override def deterministic: Boolean = true

  /** Resets the buffer: sum (slot 0) and count (slot 1) both start at zero. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  /**
   * Folds one input row into the buffer.
   *
   * @param buffer running (sum, count) to update
   * @param input  the function arguments for the current row, i.e. (name, salary)
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    // Sum and count must be updated together: counting a row whose salary is
    // NULL would silently pull the average towards zero.
    if (!input.isNullAt(1)) {
      buffer(0) = buffer.getLong(0) + input.getLong(1)
      buffer(1) = buffer.getLong(1) + 1L
    }

  /**
   * Merges two partial aggregation buffers.
   *
   * @param buffer1 mutable buffer that receives the combined result
   * @param buffer2 read-only buffer whose sum and count are added to `buffer1`
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /**
   * Computes the final result from the merged buffer.
   *
   * @return sum / count as a Double, or 0.0 when no rows were counted
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) 0D else buffer.getLong(0) / count.toDouble
  }
}
/**
 * Demo for the untyped UDAF: registers `MyAverage` under the SQL name `myAverage`,
 * loads space-separated "name salary" lines from a text file, and prints the
 * average salary computed through a SQL query.
 *
 * Relies on a SparkSession named `spark` being in scope; it is not created here.
 *
 * @param args command-line arguments (unused)
 */
def main(args: Array[String]): Unit = {
  // Make the aggregate function callable from SQL.
  spark.udf.register("myAverage", MyAverage)

  // Backslashes have to be escaped in an ordinary string literal: unescaped,
  // "\U", "\A" and "\D" are invalid escapes and "\t" would turn into a tab.
  val lineDS: Dataset[String] =
    spark.read.textFile("C:\\Users\\ASUS\\Desktop\\test2_12.txt")

  // Brings the implicit Encoder[Employee] required by `map` into scope.
  import spark.implicits._

  // Each line is expected to look like "<name> <salary>".
  val employeeDS: Dataset[Employee] = lineDS.map { line =>
    val items = line.split(" ")
    Employee(items(0), items(1).toLong)
  }
  employeeDS.createOrReplaceTempView("view_employee")

  val averageDF = spark.sql(
    """
      |select myAverage(name,salary) as avg_salary from view_employee
    """.stripMargin)
  averageDF.show(false)
}
- 强类型的自定义聚合函数:输入、缓存和输出的类型通过泛型指定,在编译时就会检查数据的类型,是安全的
package com.huawei.appgallery.udf
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
/**
* author:Chen
* 继承的包是org.apache.spark.sql.expressions.Aggregator
* 不是org.apache.spark.Aggregator
* 指定泛型
* inputschema的类型
* buffer的类型
* 输出的类型
* date:2020/2/12 14:30
*/
// One input record: an employee's name and salary.
case class Employee(name: String, salary: Long)
// Intermediate aggregation state (running sum and count); the fields are `var`s
// because MyAverage2 updates the buffer in place.
case class Buffer(var sum: Long, var count: Long)
/**
 * Typed aggregator that computes the average salary of [[Employee]] records.
 *
 * Type parameters of `Aggregator`: input = `Employee`, intermediate buffer =
 * `Buffer` (running sum and count), output = `Double`.
 */
object MyAverage2 extends Aggregator[Employee, Buffer, Double] {

  /** Initial buffer (counterpart of `initialize` in the untyped version). */
  override def zero: Buffer = Buffer(0L, 0L)

  /**
   * Folds one employee into the buffer (counterpart of `update`).
   *
   * `salary` is a primitive `Long`, which can be neither null nor NaN, so no
   * guard is needed before adding it. The buffer is updated in place and returned.
   */
  override def reduce(b: Buffer, a: Employee): Buffer = {
    b.sum += a.salary
    b.count += 1L
    b
  }

  /** Combines two partial buffers; `b1` is updated in place and returned. */
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  /**
   * Produces the final result (counterpart of `evaluate`).
   *
   * @return sum / count as a Double, or 0.0 when nothing was aggregated
   */
  override def finish(reduction: Buffer): Double =
    if (reduction.count == 0L) 0D
    else reduction.sum / reduction.count.toDouble

  /** Encoder for the intermediate buffer type. */
  override def bufferEncoder: Encoder[Buffer] = Encoders.product[Buffer]

  /** Encoder for the final result type. */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
//Dataset 引入了新的序列化方式 Encoder[T],用来代替之前的 Java 序列化和 Kryo 序列化
/**
 * Demo for the typed aggregator: loads space-separated "name salary" lines from a
 * text file into a Dataset[Employee] and prints the average salary computed by
 * `MyAverage2`.
 *
 * Relies on a SparkSession named `spark` being in scope; it is not created here.
 *
 * @param args command-line arguments (unused)
 */
def main(args: Array[String]): Unit = {
  // Backslashes have to be escaped in an ordinary string literal: unescaped,
  // "\U", "\A" and "\D" are invalid escapes and "\t" would turn into a tab.
  val lineDS: Dataset[String] =
    spark.read.textFile("C:\\Users\\ASUS\\Desktop\\test2_12.txt")

  // Brings the implicit Encoder[Employee] required by `map` into scope.
  import spark.implicits._

  // Each line is expected to look like "<name> <salary>".
  val employeeDS: Dataset[Employee] = lineDS.map { line =>
    val items = line.split(" ")
    Employee(items(0), items(1).toLong)
  }
  employeeDS.show(false)

  // Turn the aggregator into a typed column named "myAverage".
  val myAverage2 = MyAverage2.toColumn.name("myAverage")

  // The aggregator's input type is Employee, so it has to be applied to a typed
  // Dataset[Employee] rather than to an untyped DataFrame.
  val resultDF = employeeDS.select(myAverage2)
  resultDF.show(false)
}
以上是关于spark的udf和udaf的注册的主要内容,如果未能解决你的问题,请参考以下文章