在 scala 中编写 udf 函数并在 pyspark 作业中使用它们

Posted

技术标签:

【中文标题】在 scala 中编写 udf 函数并在 pyspark 作业中使用它们【英文标题】:Writing udf function in scala and using them in pyspark job 【发布时间】:2018-12-10 06:58:05 【问题描述】:

我们正在尝试编写一个 scala udf 函数并从 pyspark 中的 map 函数调用它。 dateframe 架构非常复杂,我们要传递给此函数的列是 StructType 数组。

trip_force_speeds = trip_details.groupby("vehicle_id","driver_id", "StartDtLocal", "EndDtLocal")\ .agg(collect_list(struct(col("event_start_dt_local"), col("force"), col("speed"), col("sec_from_start"), col("sec_from_end"), col("StartDtLocal"), col("EndDtLocal"), col("verisk_vehicle_id"), col("trip_duration_sec")))\ .alias("trip_details"))

在我们的地图函数中,我们需要进行一些计算。

def calculateVariables(rec: Row):HashMap[String,Float] = 
val trips = rec.getAs[List]("trips")
val base_variables = new HashMap[String, Float]()   

val entropy_variables = new HashMap[String, Float]()

val week_day_list = List("monday", "tuesday", "wednesday", "thursday", "friday")

for (trip <- trips)

  if (trip("start_dt_local") >= trip("StartDtLocal") && trip("start_dt_local") <= trip("EndDtLocal"))
  
    base_variables("trip_summary_count") += 1

    if (trip("duration_sec").toFloat >= 300 && trip("duration_sec").toFloat <= 1800) 
      base_variables ("bounded_trip") +=  1

      base_variables("bounded_trip_duration") = trip("duration_sec") + base_variables("bounded_trip_duration")

      base_variables("total_bin_1") += 30

      base_variables("total_bin_2") += 30

      base_variables("total_bin_3") += 60

      base_variables("total_bin_5") += 60

      base_variables("total_bin_6") += 30

      base_variables("total_bin_7") += 30
    
    if (trip("duration_sec") > 120 && trip("duration_sec") < 21600 )
    
      base_variables("trip_count") += 1
    

    base_variables("trip_distance") += trip("distance_km")

    base_variables("trip_duration") = trip("duration_sec") + base_variables("trip_duration")

    base_variables("speed_event_distance") = trip("speed_event_distance_km")  + base_variables("speed_event_distance")

    base_variables("speed_event_duration") = trip("speed_event_duration_sec") + base_variables("speed_event_duration")

    base_variables("speed_event_distance_ratio") = trip("speed_distance_ratio") + base_variables("speed_event_distance_ratio")

    base_variables("speed_event_duration_ratio") = trip("speed_duration_ratio") + base_variables("speed_event_duration_ratio")

  

return base_variables

当我们尝试编译 scala 代码时出现错误

我尝试使用 Row 但收到此错误

"错误:类型实参(List)的种类不符合类型参数(类型T)的预期类型。List的类型参数与类型T的预期参数不匹配:类型List有一个类型参数,但类型T 没有——"

在我的情况下,行程是行列表。这是架构

StructType(List(StructField(verisk_vehicle_id,StringType,true),StructField(verisk_driver_id,StringType,false),StructField(StartDtLocal,TimestampType,true),StructField(EndDtLocal,TimestampType,true),StructField(trips,ArrayType(StructType(List(StructField(week_start_dt_local,TimestampType,true),StructField(week_end_dt_local,TimestampType,true),StructField(start_dt_local,TimestampType,true),StructField(end_dt_local,TimestampType,true),StructField(StartDtLocal,TimestampType,true),StructField(EndDtLocal,TimestampType,true),StructField(verisk_vehicle_id,StringType,true),StructField(duration_sec,FloatType,true),StructField(distance_km,FloatType,true),StructField(speed_distance_ratio,FloatType,true),StructField(speed_duration_ratio,FloatType,true),StructField(speed_event_distance_km,FloatType,true),StructField(speed_event_duration_sec,FloatType,true))),true),true),StructField(trip_details,ArrayType(StructType(List(StructField(event_start_dt_local,TimestampType,true),StructField(force,FloatType,true),StructField(speed,FloatType,true),StructField(sec_from_start,FloatType,true),StructField(sec_from_end,FloatType,true),StructField(StartDtLocal,TimestampType,true),StructField(EndDtLocal,TimestampType,true),StructField(verisk_vehicle_id,StringType,true),StructField(trip_duration_sec,FloatType,true))),true),true)))

我们定义函数签名的方式是否有问题,我们尝试覆盖 spark structtype,但这对我不起作用。

我来自 python 背景,在 python 工作中面临一些性能问题,这就是我决定在 Scala 中编写这个 map 函数的原因。

【问题讨论】:

【参考方案1】:

您必须在 udf 中使用 Row 类型而不是 StructType。 StructType 代表架构本身而不是数据。您可以使用的 Scala 中的一个小示例:

object test

  import org.apache.spark.sql.functions.udf, collect_list, struct

  val hash = HashMap[String, Float]("start_dt_local" -> 0)
  // This simple type to store you results
  val sampleDataset = Seq(Row(Instant.now().toEpochMilli, Instant.now().toEpochMilli))

  implicit val spark: SparkSession =
    SparkSession
      .builder()
      .appName("Test")
      .master("local[*]")
      .getOrCreate()

  def calculateVariablesUdf = udf  trip: Row =>

    if(trip.getAs[Long]("start_dt_local") >= trip.getAs[Long]("StartDtLocal")) 
      // crate a new instance with your results
      hash("start_dt_local") + 1
     else 
      hash("start_dt_local") + 0
    

  


  def main(args: Array[String]) : Unit = 

    Logger.getLogger("org").setLevel(Level.OFF)
    Logger.getLogger("akka").setLevel(Level.OFF)

    val rdd = spark.sparkContext.parallelize(sampleDataset)
    val df = spark.createDataFrame(rdd, StructType(List(StructField("start_dt_local", LongType, false), StructField("StartDtLocal", LongType, false))))

    df.agg(collect_list(calculateVariablesUdf(struct(col("start_dt_local"), col("StartDtLocal")))).as("result")).show(false)

  

编辑。为了更好地理解:

当您考虑架构描述时,您错了:将 StructType(List(StructField)) 作为字段的类型。您的 DataFrame 中没有 List 类型。

如果您将 calculateVariables 视为 udf,则不需要 for 循环。我的意思是:

def calculateVariables = udf  trip: Row =>
  trip("start_dt_local").getAs[Long] 
  // your logic ....


正如我在示例中所说,您可以直接在 udf 中返回更新后的哈希

【讨论】:

完成。您可以使用 java.sql.Timestamp 来处理您的时间戳类型。 List 的错误是因为 Scala 中的 List Type 是一种阻碍类型,所以您需要声明列表元素的类型:List[Int][ 或 List[String] 等 ...

以上是关于在 scala 中编写 udf 函数并在 pyspark 作业中使用它们的主要内容,如果未能解决你的问题,请参考以下文章

Hive UDF 在 Scala 中处理整数数组

使用 scala 在 spark sql 中编写 UDF

如何使用scala将特定函数转换为apache spark中的udf函数? [复制]

从 Scala 将 UDF 注册到 SqlContext 以在 PySpark 中使用

Scala udf UnsupportedOperationException

如何将火花行(StructType)投射到scala案例类