How to define and use a User-Defined Aggregate Function in Spark SQL?
Posted: 2015-08-19 16:28:04

[Question]: I know how to write a UDF in Spark SQL:
def belowThreshold(power: Int): Boolean = power < -40
sqlContext.udf.register("belowThreshold", belowThreshold _)
Can I do something similar to define an aggregate function instead? How is that done?
For context, I want to run the following SQL query:
val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
FROM ifDF
WHERE opticalReceivePower IS NOT null
GROUP BY span, timestamp
ORDER BY span""")
It should return something like
Row(span1, false, T0)
I want the aggregate function to tell me whether any opticalReceivePower values in the groups defined by span and timestamp are below the threshold. Does my UDAF need to be written differently from the UDF pasted above?
[Question comments]:
Related: ***.com/questions/33899977/… Maybe use reduceByKey / foldByKey as recommended by zero323
Looking at the docs (spark.apache.org/docs/2.4.0/sql-pyspark-pandas-with-arrow.html) helped me far more than this answer or any of the related answers. The answer covers Spark >= 2.3, but I ran into problems on 2.4.
[Answer 1]:
Supported methods
Spark >= 3.0
The Scala UserDefinedAggregateFunction is deprecated (SPARK-30423 Deprecate UserDefinedAggregateFunction) in favor of a registered Aggregator.
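As a rough sketch of the Spark 3.0+ approach (the object name BelowThresholdAgg and the hard-coded -40 threshold are illustrative assumptions, not part of the original answer), a typed Aggregator can be registered for SQL use via functions.udaf:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

// Illustrative Aggregator: true if any power value in the group is below -40
object BelowThresholdAgg extends Aggregator[Int, Boolean, Boolean] {
  def zero: Boolean = false
  def reduce(acc: Boolean, power: Int): Boolean = acc || power < -40
  def merge(acc1: Boolean, acc2: Boolean): Boolean = acc1 || acc2
  def finish(acc: Boolean): Boolean = acc
  def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
  def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
}

// Register it so it can be called from SQL (nulls should still be filtered out,
// as in the WHERE clause of the original query), e.g.
// spark.sql("SELECT group, belowThreshold(power) FROM df GROUP BY group")
spark.udf.register("belowThreshold", functions.udaf(BelowThresholdAgg))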
Spark >= 2.3
Vectorized UDFs (Python only):
from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType
from pyspark.sql.types import *
import pandas as pd

df = sc.parallelize([
    ("a", 0), ("a", 1), ("b", 30), ("b", -50)
]).toDF(["group", "power"])

def below_threshold(threshold, group="group", power="power"):
    @pandas_udf("struct<group: string, below_threshold: boolean>", PandasUDFType.GROUPED_MAP)
    def below_threshold_(df):
        df = pd.DataFrame(
            df.groupby(group).apply(lambda x: (x[power] < threshold).any()))
        df.reset_index(inplace=True, drop=False)
        return df

    return below_threshold_
Example usage:
df.groupBy("group").apply(below_threshold(-40)).show()
## +-----+---------------+
## |group|below_threshold|
## +-----+---------------+
## | b| true|
## | a| false|
## +-----+---------------+
See also Applying UDFs on GroupedData in PySpark (with functioning python example).
Spark >= 2.0 (optionally 1.6, but with a slightly different API):
It is possible to use Aggregators on typed Datasets:
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

class BelowThreshold[I](f: I => Boolean) extends Aggregator[I, Boolean, Boolean]
    with Serializable {
  def zero = false
  def reduce(acc: Boolean, x: I) = acc | f(x)
  def merge(acc1: Boolean, acc2: Boolean) = acc1 | acc2
  def finish(acc: Boolean) = acc

  def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
  def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
}

val belowThreshold = new BelowThreshold[(String, Int)](_._2 < -40).toColumn
df.as[(String, Int)].groupByKey(_._1).agg(belowThreshold)
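If you want the result column to carry a readable name in the output, the TypedColumn can be named before it is passed to agg; a small variation on the snippet above (belowThresholdNamed is just an illustrative name):

val belowThresholdNamed = new BelowThreshold[(String, Int)](_._2 < -40)
  .toColumn
  .name("belowThreshold")

df.as[(String, Int)].groupByKey(_._1).agg(belowThresholdNamed).show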
Spark >= 1.5:
In Spark 1.5 you can create a UDAF like this, although it is most likely overkill:
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

object belowThreshold extends UserDefinedAggregateFunction {
    // Schema you get as an input
    def inputSchema = new StructType().add("power", IntegerType)
    // Schema of the row which is used for aggregation
    def bufferSchema = new StructType().add("ind", BooleanType)
    // Returned type
    def dataType = BooleanType
    // Self-explaining
    def deterministic = true
    // zero value
    def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, false)
    // Similar to seqOp in aggregate
    def update(buffer: MutableAggregationBuffer, input: Row) = {
        if (!input.isNullAt(0))
            buffer.update(0, buffer.getBoolean(0) | input.getInt(0) < -40)
    }
    // Similar to combOp in aggregate
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
        buffer1.update(0, buffer1.getBoolean(0) | buffer2.getBoolean(0))
    }
    // Called on exit to get return value
    def evaluate(buffer: Row) = buffer.getBoolean(0)
}
Example usage:
df
.groupBy($"group")
.agg(belowThreshold($"power").alias("belowThreshold"))
.show
// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// | a| false|
// | b| true|
// +-----+--------------+
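To call it from SQL, as in the query from the question, the UDAF can also be registered by name. A minimal sketch, assuming the belowThreshold object above and that the source DataFrame has been registered as a temporary table named ifDF:

sqlContext.udf.register("belowThreshold", belowThreshold)

val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
  FROM ifDF
  WHERE opticalReceivePower IS NOT NULL
  GROUP BY span, timestamp
  ORDER BY span""")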
Spark 1.4 workaround:
I am not sure if I correctly understand your requirements, but as far as I can tell, plain old aggregation should be enough here:
val df = sc.parallelize(Seq(
("a", 0), ("a", 1), ("b", 30), ("b", -50))).toDF("group", "power")
df
.withColumn("belowThreshold", ($"power".lt(-40)).cast(IntegerType))
.groupBy($"group")
.agg(sum($"belowThreshold").notEqual(0).alias("belowThreshold"))
.show
// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// | a| false|
// | b| true|
// +-----+--------------+
Spark <= 1.4:
As far as I know, at this moment (Spark 1.4.1) there is no support for UDAFs other than the Hive ones. It should become possible with Spark 1.5 (see SPARK-3947).
Unsupported / internal methods
Internally, Spark uses a number of classes, including ImperativeAggregates and DeclarativeAggregates.
These are intended for internal use and may change without notice, so they are probably not something you want to use in production code, but just for completeness, BelowThreshold could be implemented with DeclarativeAggregate like this (tested with Spark 2.2-SNAPSHOT):
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

case class BelowThreshold(child: Expression, threshold: Expression)
    extends DeclarativeAggregate {
  override def children: Seq[Expression] = Seq(child, threshold)

  override def nullable: Boolean = false
  override def dataType: DataType = BooleanType

  private lazy val belowThreshold = AttributeReference(
    "belowThreshold", BooleanType, nullable = false
  )()

  // Used to derive schema
  override lazy val aggBufferAttributes = belowThreshold :: Nil

  override lazy val initialValues = Seq(
    Literal(false)
  )

  override lazy val updateExpressions = Seq(Or(
    belowThreshold,
    If(IsNull(child), Literal(false), LessThan(child, threshold))
  ))

  override lazy val mergeExpressions = Seq(
    Or(belowThreshold.left, belowThreshold.right)
  )

  override lazy val evaluateExpression = belowThreshold
  override def defaultResult: Option[Literal] = Option(Literal(false))
}
It should be further wrapped with an equivalent of withAggregateFunction.
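Just for illustration (still relying on internal, unstable APIs), a minimal equivalent might look roughly like the following; the helper and the usage in the comment are assumptions, not Spark's public API:

import org.apache.spark.sql.Column

// Rough, assumed equivalent of Spark's internal withAggregateFunction helper:
// wrap the DeclarativeAggregate in an AggregateExpression and expose it as a Column.
def belowThreshold(col: Column, threshold: Column): Column =
  new Column(BelowThreshold(col.expr, threshold.expr).toAggregateExpression())

// Hypothetical usage:
// df.groupBy($"group").agg(belowThreshold($"power", lit(-40)).alias("belowThreshold"))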
[Comments]:
As of Spark 2.0.1, Aggregator can be used with groupBy and groupByKey (see github.com/apache/spark/blob/master/sql/core/src/test/scala/org/…). Unfortunately, Aggregator does not work with window functions, where a UserDefinedAggregateFunction is required.
[Answer 2]:
Define a UDF in Spark (3.0+) with Java:
private static UDF1<Integer, Boolean> belowThreshold = (power) -> power < -40;
Register the UDF:
SparkSession.builder()
.appName(appName)
.master(master)
.getOrCreate().udf().register("belowThreshold", belowThreshold, BooleanType);
Use the UDF via Spark SQL:
spark.sql("SELECT belowThreshold('50')");
[Comments]: