Spark UserDefinedAggregateFunction: scala.MatchError 0.0 (of class java.lang.Double)

Posted 2016-11-21 09:47:41

Question:

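I'm trying to use a UserDefinedAggregateFunction on Spark 2.0.2 with Scala, but I'm getting a match error. I've created the following as a test case; the code I'm actually writing is similar to it.

I'm trying to accumulate values over a window with an aggregate. It isn't just a cumulative sum: I need to work out how much to carry forward based on some conditions.

As a test case, I've created an amortisation table, in which I have to calculate the opening and closing balance for each month.

The data looks like this: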

+------+--------+------------+---------+
|Period| Capital|InterestRate|Repayment|
+------+--------+------------+---------+
|201601|   0.00 |       0.10 |    0.00 |
|201602|1000.00 |       0.00 |    0.00 |
|201603|2000.00 |       0.10 |    0.00 |
|201604|   0.00 |       0.10 | -200.00 |
|201605|   0.00 |       0.10 | -200.00 |
|201606|   0.00 |       0.10 | -200.00 |
|201607|   0.00 |       0.10 | -200.00 |
|201608|   0.00 |       0.00 | -200.00 |
|201609|   0.00 |       0.10 | -200.00 |
|201610|   0.00 |       0.10 | -200.00 |
|201611|   0.00 |       0.10 | -200.00 |
|201612|   0.00 |       0.10 | -200.00 |
+------+--------+------------+---------+

I couldn't format the CSV properly here, but I've added it to a gist: https://gist.github.com/nevi-me/8b2362a5365e73af947fc13bb5836adc.

I'm trying to calculate the Opening and Closing balances, and then return the Closing balance from the aggregation.
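Conceptually, an ordered window whose frame runs from the first row up to the current row (the stack trace below shows Spark using an UnboundedPrecedingWindowFunctionFrame here) initializes the aggregation buffer once and then, for each row in order, calls update and emits the result of evaluate. The plain-Scala sketch below mimics that lifecycle without Spark. MiniAggregate, MiniWindow and RunningCapital are hypothetical names made up for illustration; this is a simplification that returns a new buffer each step, whereas Spark mutates a MutableAggregationBuffer in place.

```scala
// Hypothetical, simplified stand-in for the UDAF lifecycle; not Spark's API.
trait MiniAggregate[In, Buf, Out] {
  def initialize(): Buf
  def update(buffer: Buf, input: In): Buf
  def evaluate(buffer: Buf): Out
}

object MiniWindow {
  // Mimics an ordered window with a "first row to current row" frame:
  // initialize once, then fold each row into the buffer and evaluate after every row.
  def runningFrame[In, Buf, Out](rows: Seq[In], agg: MiniAggregate[In, Buf, Out]): Seq[Out] =
    rows
      .scanLeft(agg.initialize())((buffer, row) => agg.update(buffer, row))
      .tail // drop the initial buffer so there is exactly one output per row
      .map(buffer => agg.evaluate(buffer))
}

// Example aggregate: a plain running sum of Capital, buffered as a single Double.
object RunningCapital extends MiniAggregate[Double, Double, Double] {
  def initialize(): Double = 0.0
  def update(buffer: Double, input: Double): Double = buffer + input
  def evaluate(buffer: Double): Double = buffer
}
```

For example, MiniWindow.runningFrame(Seq(0.0, 1000.0, 2000.0, 0.0), RunningCapital) yields one running total per row. An amortisation aggregate would use a two-slot buffer (opening, closing) instead of a single Double.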

Scala:

package me.nevi

import org.apache.spark.sql._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.types.{StructType, DoubleType, DataType}

object AggregationTest {

  object amortisedClosingBalance extends UserDefinedAggregateFunction {
    override def inputSchema: StructType = new StructType().add("Capital", DoubleType).add("InterestRate", DoubleType).add("Repayment", DoubleType)

    override def bufferSchema: StructType = new StructType().add("Opening", DoubleType).add("Closing", DoubleType)

    override def dataType: DataType = new StructType().add("Closing", DoubleType)

    override def deterministic: Boolean = true

    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, 0.0)
      buffer.update(1, 0.0)
    }

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        println(buffer.get(0))
        println(buffer.get(1))
        buffer.update(0, buffer.getDouble(1))
        // (opening + capital) * interestrate - repayment
        buffer.update(1, (buffer.getDouble(0) + input.getDouble(0)) * input.getDouble(1) + input.getDouble(2))
      } else {
        // if first record?
        buffer.update(0, input.getDouble(0))
        buffer.update(1, input.getDouble(0))
      }
    }

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getDouble(0))
      buffer1.update(1, buffer1.getDouble(1))
    }

    override def evaluate(buffer: Row): Any = {
      buffer.getDouble(1)
    }
  }

  def main(args: Array[String]): Unit = {
    System.setProperty("hadoop.home.dir", "C:/spark")
    System.setProperty("spark.sql.warehouse.dir", "file:///tmp/spark-warehouse")

    val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("Aggregation Test")
      .getOrCreate()

    import spark.implicits._

    val df = spark.read.option("header", true).csv("file:///d:/interest_calc.csv")

    df.show()

    val windowSpec = Window.orderBy(df.col("Period"))

    val calc = df.withColumn("Closing", amortisedClosingBalance($"Capital", $"InterestRate", $"Repayment").over(windowSpec))

    calc.show()
  }
}

I get this exception:

scala.MatchError: 0.0 (of class java.lang.Double)
  at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:256)
  at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:251)
  at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103)
  at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:403)
  at org.apache.spark.sql.execution.aggregate.ScalaUDAF.eval(udaf.scala:440)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source)
  at org.apache.spark.sql.execution.AggregateProcessor.evaluate(WindowExec.scala:1029)
  at org.apache.spark.sql.execution.UnboundedPrecedingWindowFunctionFrame.write(WindowExec.scala:822)
  at org.apache.spark.sql.execution.WindowExec$$anonfun$15$$anon$1.next(WindowExec.scala:398)
  at org.apache.spark.sql.execution.WindowExec$$anonfun$15$$anon$1.next(WindowExec.scala:289)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)
  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)
  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:283)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)
  at org.apache.spark.scheduler.Task.run(Task.scala:86)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
  at java.lang.Thread.run(Thread.java:745)

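Does anyone know what I'm doing wrong? I was originally on Spark 2.0.0, but I came across someone else with a similar problem with UDTFs who was advised to upgrade to 2.0.1. After upgrading, my problem persists.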


Solution, for anyone interested

As the accepted answer points out, the problem was my schema. Below is the snippet that now computes correctly.

package me.nevi

import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

object AggregationTest {

  object amortisedClosingBalance extends UserDefinedAggregateFunction {
    override def inputSchema: StructType = new StructType().add("Capital", DoubleType).add("InterestRate", DoubleType).add("Repayment", DoubleType)

    override def bufferSchema: StructType = new StructType().add("Opening", DoubleType).add("Closing", DoubleType)

    // The result is a struct of two doubles, so evaluate must return a Row with two fields
    override def dataType: DataType = new StructType().add("Opening", DoubleType).add("Closing", DoubleType)

    override def deterministic: Boolean = true

    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, 0.0)
      buffer.update(1, 0.0)
    }

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        println(buffer.get(0))
        println(buffer.get(1))
        // the previous closing balance becomes this month's opening balance
        buffer.update(0, buffer.getDouble(1))
        // closing = capital + opening + repayment + (opening + capital) * (annual rate / 12)
        // repayments are stored as negative amounts
        buffer.update(1, input.getDouble(0)
          + buffer.getDouble(0) + input.getDouble(2) + (buffer.getDouble(0) + input.getDouble(0)) * (input.getDouble(1) / 12))
      } else {
        // if first record?
        buffer.update(0, input.getDouble(0))
        buffer.update(1, input.getDouble(0))
      }
    }

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getDouble(0))
      buffer1.update(1, buffer1.getDouble(1))
    }

    // Matches dataType: a Row with (Opening, Closing)
    override def evaluate(buffer: Row): Any = {
      Row(buffer.getDouble(0), buffer.getDouble(1))
    }
  }

  def main(args: Array[String]): Unit = {
    System.setProperty("hadoop.home.dir", "C:/spark")
    System.setProperty("spark.sql.warehouse.dir", "file:///tmp/spark-warehouse")

    val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("Aggregation Test")
      .getOrCreate()

    import spark.implicits._

    val df = spark.read.option("header", true).csv("file:///d:/interest_calc.csv")

    df.show()

    val windowSpec = Window.orderBy(df.col("Period").asc)

    var calc = df.withColumn("Calcs", amortisedClosingBalance($"Capital", $"InterestRate", $"Repayment").over(windowSpec))
    calc = calc.withColumn("Opening", round($"Calcs".getField("Opening"), 2)).withColumn("Closing", round($"Calcs".getField("Closing"),2))
      .drop("Calcs")

    calc.show()
  }
}
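Written out as a recurrence, the corrected update computes the following. The symbols are introduced here for illustration only: O_t is the opening balance, K_t the closing balance, C_t the capital, r_t the annual interest rate and R_t the repayment (negative in this data); K_0 = 0 is the value set by initialize. The last line works through 201604 with the figures from the input table.

```latex
\begin{aligned}
O_t &= K_{t-1}, \qquad K_0 = 0,\\
K_t &= O_t + C_t + R_t + (O_t + C_t)\,\frac{r_t}{12},\\[4pt]
K_{201604} &= 3025 + 0 + (-200) + (3025 + 0)\cdot\frac{0.10}{12}
            = 2825 + 25.208\overline{3} \approx 2850.21.
\end{aligned}
```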

The result is:

+------+--------+------------+---------+-------+-------+
|Period| Capital|InterestRate|Repayment|Opening|Closing|
+------+--------+------------+---------+-------+-------+
|201601|   0.00 |       0.10 |    0.00 |    0.0|    0.0|
|201602|1000.00 |       0.00 |    0.00 |    0.0| 1000.0|
|201603|2000.00 |       0.10 |    0.00 | 1000.0| 3025.0|
|201604|   0.00 |       0.10 | -200.00 | 3025.0|2850.21|
|201605|   0.00 |       0.10 | -200.00 |2850.21|2673.96|
|201606|   0.00 |       0.10 | -200.00 |2673.96|2496.24|
|201607|   0.00 |       0.10 | -200.00 |2496.24|2317.05|
|201608|   0.00 |       0.00 | -200.00 |2317.05|2117.05|
|201609|   0.00 |       0.10 | -200.00 |2117.05|1934.69|
|201610|   0.00 |       0.10 | -200.00 |1934.69|1750.81|
|201611|   0.00 |       0.10 | -200.00 |1750.81| 1565.4|
|201612|   0.00 |       0.10 | -200.00 | 1565.4|1378.44|
+------+--------+------------+---------+-------+-------+
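As a sanity check that needs no Spark, the same arithmetic can be replayed over the input rows with scanLeft, carrying each closing balance forward as the next opening balance. AmortisationCheck, Month, step and schedule are names made up for this sketch. Rounding uses HALF_UP to two decimals, on the assumption that this mirrors what round(..., 2) does in the Spark job above.

```scala
// Replays the solution's update() arithmetic in plain Scala (no Spark).
object AmortisationCheck {
  // One input row: (Period, Capital, InterestRate, Repayment) from the input table.
  final case class Month(period: Int, capital: Double, rate: Double, repayment: Double)

  val months: List[Month] = List(
    Month(201601, 0.0, 0.10, 0.0),
    Month(201602, 1000.0, 0.00, 0.0),
    Month(201603, 2000.0, 0.10, 0.0),
    Month(201604, 0.0, 0.10, -200.0),
    Month(201605, 0.0, 0.10, -200.0),
    Month(201606, 0.0, 0.10, -200.0),
    Month(201607, 0.0, 0.10, -200.0),
    Month(201608, 0.0, 0.00, -200.0),
    Month(201609, 0.0, 0.10, -200.0),
    Month(201610, 0.0, 0.10, -200.0),
    Month(201611, 0.0, 0.10, -200.0),
    Month(201612, 0.0, 0.10, -200.0)
  )

  // Closing balance for one month given its opening balance; repayments are negative.
  def step(opening: Double, m: Month): Double =
    opening + m.capital + m.repayment + (opening + m.capital) * (m.rate / 12)

  // (Opening, Closing) per month, each rounded to 2 decimals with HALF_UP.
  def schedule(ms: List[Month]): List[(Double, Double)] = {
    // balances(0) is the initial buffer value; balances(i + 1) is month i's closing.
    val balances = ms.scanLeft(0.0)((prevClosing, m) => step(prevClosing, m))
    balances.init.zip(balances.tail).map { case (o, c) => (round2(o), round2(c)) }
  }

  private def round2(x: Double): Double =
    BigDecimal(x).setScale(2, BigDecimal.RoundingMode.HALF_UP).toDouble
}
```

Calling AmortisationCheck.schedule(AmortisationCheck.months) gives the same Opening and Closing columns as the table above, which is a quick way to test the formula before wiring it into a UDAF.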


Answer 1:

You get the exception because dataType is defined incorrectly. You declared it as:

StructType(StructField(Closing,DoubleType,true))

but what you actually return is a scalar. It should be defined as:

override def dataType: DataType = DoubleType

Alternatively, you can redefine evaluate, for example:

override def evaluate(buffer: Row): Any = {
  Row(buffer.getDouble(1))
}

The latter returns a nested column:

 |-- Closing: struct (nullable = true)
 |    |-- Closing: double (nullable = true)

so it may not be what you are looking for.
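To see why a mismatch between dataType and evaluate surfaces as scala.MatchError rather than a type error, here is a heavily simplified, hypothetical mimic of the conversion step at the top of the stack trace (CatalystTypeConverters$StructConverter.toCatalystImpl). SimpleType, SimpleDouble, SimpleStruct, SimpleRow and ConverterSketch are invented stand-ins, not Spark classes; the only assumption carried over from Spark is that the struct converter pattern-matches on row-like values and has no fallback case.

```scala
// Hypothetical, heavily simplified stand-ins for Spark's DataType and Row; not Spark code.
sealed trait SimpleType
case object SimpleDouble extends SimpleType
final case class SimpleStruct(fields: List[(String, SimpleType)]) extends SimpleType

final case class SimpleRow(values: Any*)

object ConverterSketch {
  // Converts evaluate()'s result according to the declared result type.
  // The struct case only handles rows, so a bare Double falls through the
  // pattern match and raises scala.MatchError, like the exception in the question.
  def convert(declared: SimpleType, value: Any): Any = declared match {
    case SimpleDouble =>
      value match { case d: Double => d }
    case SimpleStruct(fields) =>
      value match {
        case row: SimpleRow =>
          fields.zip(row.values).map { case ((_, fieldType), v) => convert(fieldType, v) }
      }
  }
}
```

Declaring a struct but returning 0.0 fails with the message "0.0 (of class java.lang.Double)", while either fix from the answer makes the two sides agree: declare a scalar type for a bare Double, or wrap the value in a row for a struct type.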

Comments:

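Thanks, that helped. In future I'll check my schema first. The nested-struct option is actually better for me, because I can return both the opening and the closing balance. I'll update my question with what I ended up doing.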
