Spark sql实现自定义函数

Posted 郭朝阳@

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Spark sql实现自定义函数相关的知识,希望对你有一定的参考价值。

Spark sql实现自定义函数

文章目录


一、为什么要自定义function?

有小伙伴可能会疑问:Spark Sql提供了编写UDF和UDAF的接口扩展,为什么还有开发自定义函数呢?

虽然Spark SQL 提供了UDF和UDAF,但是当我们想要实现 原生函数一样的功能比如:语义参数 ,可变参数等 功能时候,UDF和UDAF就无法满足。

例如 我们想要实现类似于substr这样的函数, udf就无法实现, 其中的参数 ‘Spark SQL’ FROM 5、还有后面两个参数中最后一个可有可无的情况下。

> SELECT substr('Spark SQL', 5);
 k SQL
> SELECT substr('Spark SQL', -3);
 SQL
> SELECT substr('Spark SQL', 5, 1);
 k
> SELECT substr('Spark SQL' FROM 5);
 k SQL
> SELECT substr('Spark SQL' FROM -3);
 SQL
> SELECT substr('Spark SQL' FROM 5 FOR 1);
 k

``

二、实现自定义的函数

spark 官网提供了 SparkSessionExtensions类 ,可以自定义的增强和扩展Spark的很多能力,例如: injectOptimizerRuleinjectOptimizerRule等等。


举个例子吧。

为什么会有这样的需求呢?
原因是我想要解决Spark SQl 中的一些函数不完全满足我想要的功能。
比如:原生的spark Sql 函数to_timestamp 在执行有些参数的时候因为数据的格式和指定的parrten不匹配导致运行为null (严格模式下会报错)
我期望的结果应该为:2020-08-08 00:00:00,而不是为null, 简言之就是parrten只要是正确的时间格式,就应该解析出来。

这里是我们的需求,如果各位其他的需求 spark Sql 中的函数不是完全满足,通过UDF能实现,就用UDF实现,或者不完全满足 就跟我这个例子一样进行重写覆盖,如果完全没有 也可以按照这个逻辑自己定义一个全新的函数实现。

解决思路:
老套路,跟踪源码找到 报null和报错的代码逻辑,开发函数,重写逻辑,然后覆盖原函数。

问题代码如下:
1.ToTimestamp的eval方法

case StringType =>
          val fmt = right.eval(input)
          if (fmt == null) 
            null
           else 
            val formatter = formatterOption.getOrElse(getFormatter(fmt.toString))
            try 
              formatter.parse(t.asInstanceOf[UTF8String].toString) / downScaleFactor
             catch 
              case e if isParseError(e) =>
                if (failOnError) 
                  throw e
                 else 
                  null
                
            

可以看出解析失败 直接catch,根据failOnError 是否为严格模式报错还是返回null
2.ToTimestamp的doGenCode方法

 case StringType => formatterOption.map  fmt =>
        val df = classOf[TimestampFormatter].getName
        val formatterName = ctx.addReferenceObj("formatter", fmt, df)
        nullSafeCodeGen(ctx, ev, (datetimeStr, _) =>
          s"""
             |try 
             |  $ev.value = $formatterName.parse($datetimeStr.toString()) / $downScaleFactor;
             | catch (java.time.DateTimeException e) 
             |  $parseErrorBranch
             | catch (java.time.format.DateTimeParseException e) 
             |  $parseErrorBranch
             | catch (java.text.ParseException e) 
             |  $parseErrorBranch
             |
             |""".stripMargin)
      

这里是拼接java代码的逻辑,逻辑和eval方法相同。

解决
1.开发逻辑
新建一个样例类继承ToTimestamp,重写上述的逻辑代码

解决思路: 当获取异常后,判断如果是应为格式问题解释失败,识别数据格式,将数据按照数据的格式解析成时间,然后再将时间类型的数据,解析成用户指定的字符串格式。详情看代码。

package v2.jdbc.spark.expressions.function

import java.text.ParseException
import java.time.format.DateTimeParseException
import java.time.DateTimeException, ZoneId
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator, CodegenContext, ExprCode
import org.apache.spark.sql.catalyst.expressions.Cast, Expression, TimeZoneAwareExpression, ToTimestamp
import org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMicros
import org.apache.spark.sql.catalyst.util.LegacyDateFormats, TimestampFormatter
import org.apache.spark.sql.catalyst.FunctionIdentifier, InternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType, DateType, StringType, TimestampType
import org.apache.spark.unsafe.types.UTF8String
import v2.jdbc.spark.expressions.extra.ExpressionUtils, FunctionDescription
import v2.jdbc.spark.expressions.function.DateTimeUtils.dateStrChangeFormat

case class BiGetTimestamp(
                           left: Expression,
                           right: Expression,
                           timeZoneId: Option[String] = None,
                           failOnError: Boolean = SQLConf.get.ansiEnabled)extends ToTimestamp 
  override val downScaleFactor = 1

  override def dataType: DataType = TimestampType

  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
    copy(timeZoneId = Option(timeZoneId))

  private def isParseError(e: Throwable): Boolean = e match 
    case _: DateTimeParseException |
         _: DateTimeException |
         _: RuntimeException |
         _: ParseException => true
    case _ => false
  

  override def eval(input: InternalRow): Any = 

    val t = left.eval(input)
    if (t == null) 
      null
     else 
      left.dataType match 
        case DateType =>
          daysToMicros(t.asInstanceOf[Int], zoneId) / downScaleFactor
        case TimestampType =>
          t.asInstanceOf[Long] / downScaleFactor
        case StringType =>
          val fmt = right.eval(input)
          if (fmt == null) 
            null
           else 
            val formatter = formatterOption.getOrElse(getFormatter(fmt.toString))
            try 
              formatter.parse(t.asInstanceOf[UTF8String].toString) / downScaleFactor
             catch 
              case e if isParseError(e)=>
                val dateStr =UTF8String.fromString(dateStrChangeFormat(t.toString, fmt.toString)).toString
                formatter.parse(dateStr) / downScaleFactor
              case other=>
              if (failOnError) 
                  throw other
                 else 
                  null
                
            
          
      
    
  

 def doGenCodeErrorProcess(str1:String, datetimeStr: String,
                      ev: ExprCode,formatterName:String,pattern:String):String=
     s"""
       |
       |  boolean year = false;
       |  $str1
       |  if(pattern.matcher($datetimeStr.toString().substring(0, 4)).matches()) 
       |      year = true;
       |  
       |  StringBuilder sb = new StringBuilder();
       |  int index = 0;
       |  if(!year) 
       |      if($datetimeStr.toString().contains("月") || $datetimeStr.toString().contains("-") || $datetimeStr.toString().contains("/")) 
       |          if(Character.isDigit($datetimeStr.toString().charAt(0))) 
       |              index = 1;
       |          
       |      else 
       |          index = 3;
       |      
       |  
       |  for (int i = 0; i < $datetimeStr.toString().length(); i++) 
       |      char chr = $datetimeStr.toString().charAt(i);
       |      if(Character.isDigit(chr)) 
       |          if(index==0) 
       |              sb.append("y");
       |          
       |          if(index==1) 
       |              sb.append("M");
       |          
       |          if(index==2) 
       |              sb.append("d");
       |          
       |          if(index==3) 
       |              sb.append("H");
       |          
       |          if(index==4) 
       |              sb.append("m");
       |          
       |          if(index==5) 
       |              sb.append("s");
       |          
       |          if(index==6) 
       |              sb.append("S");
       |          
       |      else 
       |          if(i>0) 
       |              char lastChar = $datetimeStr.toString().charAt(i-1);
       |              if(Character.isDigit(lastChar)) 
       |                  index++;
       |              
       |          
       |          sb.append(chr);
       |      
       |  
       |  java.text.SimpleDateFormat simpleDateFormat = new java.text.SimpleDateFormat(sb.toString());
       |  java.util.Date date = simpleDateFormat.parse($datetimeStr.toString());
       |  java.text.SimpleDateFormat simpleDateFormat2 = new java.text.SimpleDateFormat("$pattern");
       |	$ev.value = $formatterName.parse(simpleDateFormat2.format(date)) / $downScaleFactor;
       |""".stripMargin

 

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = 
    val javaType = CodeGenerator.javaType(dataType)
    val parseErrorBranch = if (failOnError) "throw e;" else s"$ev.isNull = true;"
    val code = left.dataType match 
      case StringType => formatterOption.map  fmt =>
        val df = classOf[TimestampFormatter].getName
        val formatterName = ctx.addReferenceObj("formatter", fmt, df)
        val patternField = fmt.getClass.getDeclaredField("pattern")
        patternField.setAccessible(true)
        val pattern = patternField.get(fmt).toString
        val str1="java.util.regex.Pattern pattern = java.util.regex.Pattern.compile(\\"^[-\\\\\\\\+]?[\\\\\\\\d]*$\\");"
        nullSafeCodeGen(ctx, ev, (datetimeStr, _) =>
          s"""
             |try 
             |  $ev.value = $formatterName.parse($datetimeStr.toString()) / $downScaleFactor;
             | catch (java.time.DateTimeException e) 
             | $doGenCodeErrorProcess(str1,datetimeStr,ev,formatterName,pattern)
             | catch (java.time.format.DateTimeParseException e) 
             |  $doGenCodeErrorProcess(str1,datetimeStr,ev,formatterName,pattern)
             | catch (java.text.ParseException e) 
             |  $doGenCodeErrorProcess(str1,datetimeStr,ev,formatterName,pattern)
             |  catch (java.lang.RuntimeException e) 
             |  $doGenCodeErrorProcess(str1, datetimeStr, ev, formatterName, pattern)
             | catch (java.lang.Exception e) 
             |  $parseErrorBranch
             |
             |""".stripMargin)
      .getOrElse 
        val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
        val tf = TimestampFormatter.getClass.getName.stripSuffix("$")
        val ldf = LegacyDateFormats.getClass.getName.stripSuffix("$")
        val timestampFormatter = ctx.freshName("timestampFormatter")
        nullSafeCodeGen(ctx, ev, (string, format) =>
          s"""
             |$tf $timestampFormatter = $tf$$.MODULE$$.apply(
             |  $format.toString(),
             |  $zid,
             |  $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT(),
             |  true);
             |try 
             |  $ev.value = $timestampFormatter.parse($string.toString()) / $downScaleFactor;
             | catch (java.time.format.DateTimeParseException e) 
             |    $parseErrorBranch
             | catch (java.time.DateTimeException e) 
             |    $parseErrorBranch
             | catch (java.text.ParseException e) 
             |    $parseErrorBranch
             |
             |""".stripMargin)
      
      case TimestampType =>
        val eval1 = left.genCode(ctx)
        ev.copy(code =
          code"""
          $eval1.code
          boolean $ev.isNull = $eval1.isNull;
          $javaType $ev.value = $CodeGenerator.defaultValue(dataType);
          if (!$ev.isNull) 
            $ev.value = $eval1.value / $downScaleFactor;
          """)
      case DateType =>
        val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
        val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
        val eval1 = left.genCode(ctx)
        ev.copy(code =
          code"""
          $eval1.code
          boolean $ev.isNull = $eval1.isNull;
          $javaType $ev.value = $CodeGenerator.defaultValue(dataType);
          if (!$ev.isNull) 
            $ev.value = $dtu.daysToMicros($eval1.value, $zid) / $downScaleFactor;
          """)
    
    code
  



object BiGetTimestamp 
    val fd: FunctionDescription = (
      new FunctionIdentifier("to_timestamp"),
      ExpressionUtils.getExpressionInfo(classOf[BiGetTimestamp], "to_timestamp"),
      (children: Seq[Expression]) =>
        children.size match 
          case 1=>
            Cast(children.head, TimestampType,Some("Asia/Shanghai"))
          case 2=>
            BiGetTimestamp(children.head,children(1),Some("Asia/Shanghai"))
          case _=>throw new Exception("参数异常")
        
       )

  val fd_toDte: FunctionDescription = (
    new FunctionIdentifier("to_date"),
    ExpressionUtils.getExpressionInfo(classOf[BiGetTimestamp], "to_date"),
    (children: Seq[Expression]) =>
      children.size match 
        case 1=>
          Cast(children.head, DateType,Some("Asia/Shanghai"))
        case 2=>
          Cast(BiGetTimestamp(children.head,children(1),Some("Asia/Shanghai")), DateType,Some("Asia/Shanghai"))
        case _=>throw new Exception("参数异常")
      
  )



package v2.jdbc.spark.expressions.function

import java.text.SimpleDateFormat
import java.util.regex.Pattern

object DateTimeUtils 
  
  /**
   * 识别日期字符串的日期格式
   */
  def identifyDateType(str: String): String = 
    var year = false
    val pattern = Pattern.compile("^[-\\\\+]?[\\\\d]*$")
    if (pattern.matcher(str.substring(0, 4)).matches) year = true
    val sb = new StringBuilder
    var index = 0
    if (!year) if (str.contains("月") || str.contains("-") || str.contains("/")) if (Character.isDigit(str.charAt(0))) index = 1
    else index = 3
    for (i <- 0 until str.length) 
      val chr = str.charAt(i)
      if (Character.isDigit(chr)) 
        if (index == 0) sb.append("y")
        if (index == 1) sb.append("M")
        if (index == 2) sb.append("d")
        if (index == 3) sb.append("H")
        if (index == 4) sb.append("m")
        if (index == 5) sb.append("s")
        if (index == 6) sb.append("S")
      
      else 
        if (i > 0) 
          val lastChar = str.charAt(i - 1)
          if (Character.isDigit(lastChar)) index += 1
        
        sb.append(chr)
      
    
    sb.toString
  


  def dateStrChangeFormat(dateStr: String,  targetFormat: String): String = 
    val sourceFormat = new SimpleDateFormat(identifyDateType(dateStr))
    val date = sourceFormat.parse(dateStr)
    val sourceFormat2 = new SimpleDateFormat(targetFormat)
    sourceFormat2.format(date)
  


2.注册函数
函数注册 有两种方式:
1.直接在构建Spark Session时候通过withExtensions直接使用。

2.不直接使用,通过SparkConf配置。
在配置参数中配置:
spark.sql.extensions=v2.jdbc.spark.expressions.extra.FunctionExtensions

三、测试效果

总结

这里是我们的需求改写并覆盖原有的函数,如果各位其他的需求 spark Sql 中的函数不是完全满足,通过UDF能实现,就用UDF实现,或者不完全满足 就跟我这个例子一样进行重写覆盖,如果完全没有 也可以按照这个逻辑 根据要实现函数的类型进行继承对应的Expression,编写eval和doGenCode方法, 自己定义一个全新的函数。

以上是关于Spark sql实现自定义函数的主要内容,如果未能解决你的问题,请参考以下文章

Spark sql实现自定义函数

Spark sql实现自定义函数

详解Spark sql用户自定义函数:UDF与UDAF

Spark SQL自定义函数

spark-sql 自定义函数

spark-sql自定义函数UDF和UDAF