极简spark教程spark聚合函数
Posted 鱼摆摆
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了极简spark教程spark聚合函数相关的知识,希望对你有一定的参考价值。
聚合函数分为两类,一种是spark内置的常用聚合函数,一种是用户自定义聚合函数
UDAF
不带类型的UDAF【较常用】
- 继承UserDefinedAggregateFunction
- 定义输入数据的schema
- 定义缓存的数据结构
- 聚合函数返回值的数据类型
- 定义聚合函数的幂等性,一般为true
- 初始化缓存
- 更新缓存
- 合并缓存
- 计算结果
import org.apache.spark.SparkConf, SparkContext
import org.apache.spark.sql.Row, SparkSession
import org.apache.spark.sql.expressions.MutableAggregationBuffer, UserDefinedAggregateFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
object avg extends UserDefinedAggregateFunction
// 定义输入数据的schema,需要指定列名,但在实际使用中这里指定的列名没有意义
override def inputSchema: StructType = StructType(List(StructField("input", LongType)))
// 缓存的数据结构,bufferSchema定义了缓存的数据结构具有sum和count两个字段
override def bufferSchema: StructType = StructType(List(StructField("sum", LongType), StructField("count", LongType)))
// 聚合函数返回值的数据类型:返回值的类型必需与下面的evaluate返回类型一致
override def dataType: DataType = LongType
// 聚合函数的幂等性,相同输入总是能得到相同输出
override def deterministic: Boolean = true
// 初始化缓存:根据bufferSchema,缓存具有sum和count两个字段,这里会对sum和count两个变量的值进行初始化
// tips:缓存buffer是MutableAggregationBuffer类型,你可以简单理解buffer就是一个数组
// tips:在这里buffer是具有代表了sum和count数值的二元数组
override def initialize(buffer: MutableAggregationBuffer): Unit =
buffer(0) = 0L
buffer(1) = 0L
// 更新缓存:接受并处理输入数据,更新buffer
// tips:在实际处理中,输入数据是DataFrame,DataFrame是由多个Row组成的,每个Row会逐个传递给update,更新buffer中的值
// tips:必须对输入的input进行检查,防止input.getLong(i)出现越界报错ArrayIndexOutOfBoundsException
override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
if(input.isNullAt(0)) return
buffer(0) = buffer.getLong(0) + input.getLong(0)
buffer(1) = buffer.getLong(1) + 1
// 合并缓存:对多个buffer进行合并,这里的合并方式类似于reduce,新来的buffer都会和左侧合并后的大buffer进行合并,合并后保留大buffer的值,buffer2会被丢弃
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
// 计算结果:根据所有buffer合并后的值,计算最终的结果
// tips:这里所有buffer合并后对值为整体的sum和count,计算整体的sum和count比值,我们得到最终的平均值
override def evaluate(buffer: Row): Any =
buffer.getLong(0) / buffer.getLong(1)
不带类型的UDAF的使用
-
在sparkSQL中使用UDAF
-
在DataFrame中使用UDAF
def main(args: Array[String]): Unit =
val spark = SparkSession.builder().master("local").getOrCreate()
// 注册UDAF函数,和UDF函数一样
spark.udf.register("my_avg", avg)
// test.txt文件内容
// score|user
// 90|Tom
// 95|Jerry
// 100|Claris
// sparkSQL读取文件,创建视图
// sparkSQL的第一步:读取文件并创建视图
spark.read.option("header","true").option("sep","|").csv("test.txt").createOrReplaceTempView("v_user")
// sparkSQL的第二步:在spark.sql中调用UDAF,求分数的均值
spark.sql("select u_avg(score) as avg_score from v_user").show()
// DataFrame的第一步:读取文件,创建DataFrame
val df1 = spark.read.option("header","true").option("sep","|").csv("data/other/test.txt")
// DataFrame的第二步:在df.agg中,使用callUDF调用UDAF函数,求分数的均值
val df2 = df1.agg(callUDF("my_avg",col("score")))
df2.show(false)
带类型的UDAF【不常用】
- 继承Aggregator,继承时须在方括号内指定输入类型、缓存类型、输出类型
- 定义作为输入类型的User,作为缓存类型的Average,返回类型为Double
- 初始化缓存
- 更新缓存
- 合并缓存
- 计算结果
- 固定操作:定义缓存编码器(一般都是Encoders.product)、输出编码器
import org.apache.spark.sql.Encoder, Encoders, SparkSession
import org.apache.spark.sql.expressions.Aggregator, Window
case class Average(var sum: Long, var count: Long)
case class User(score: String, name:String)
// 继承Aggregator需要指定输入类型User、缓存类型Average、输出类型Double
object avg1 extends Aggregator[User,Average,Double]
// 初始化缓存:这里的缓存为一个Average实例,第一个0L代表sum,第二个0L代表count
override def zero: Average = Average(0L, 0L)
// 更新缓存:接受一个User类型,解析出需要的字段,进行累积计算
override def reduce(b: Average, a: User): Average =
b.sum += a.score.toLong
b.count += 1L
b
// 合并缓存:对多个缓存(Average对象)进行合并,所有右侧的Average会逐个合并到最左侧的Average,返回左侧的Average
override def merge(b1: Average, b2: Average): Average =
b1.sum += b2.sum
b1.count += b2.count
b1
// 计算结果:根据合并后的结果计算最终结果
override def finish(reduction: Average): Double =
reduction.sum.toDouble / reduction.count.toDouble
// 缓存编码器:注意左侧返回类型为Encoder[Average],只要是自定义类型,右侧一般都是Encoders.product
override def bufferEncoder: Encoder[Average] = Encoders.product
// 输出编码器:对输出进行编码,编码为java兼容的Double类型
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
带类型的UDAF的使用
- 在dataSet中结合select使用UDAF
def main(args: Array[String]): Unit =
val spark = SparkSession.builder().master("local").getOrCreate()
// test.txt文件内容
// score|user
// 90|Tom
// 95|Jerry
// 100|Claris
// DataSet的第一步:导入隐式转换,否则读取文件并调用as[U]时会报错
import spark.implicits._
// DataSet的第二步:读取文件,创建DataSet,这里由于读取的是csv文件,score字段默认为字符串类型,与User样例类中的类型保持一致,否则会报错String cannot cast to int
val df1 = spark.read.option("header","true").option("sep","|").csv("data/other/test.txt").as[User]
df1.show(false)
// DataFrame的第二步:在df.select中调用UDAF,求分数的均值
val df2 = df1.select(avg1.toColumn.name("test"))
df2.show(false)
以上是关于极简spark教程spark聚合函数的主要内容,如果未能解决你的问题,请参考以下文章
Dataframe Spark Scala中的最后一个聚合函数
Spark 系列—— Spark SQL 聚合函数 Aggregations