How to do member-wise operations on array in spark sql?

Posted: 2018-10-16 19:39:12

Question:

In spark-sql, I have a dataframe with a column col containing an Int array of size 100 (for example).

I would like to aggregate this column into a single value: an Int array of size 100 containing the member-wise sum of every element of the column. This can be done by calling:

dataframe.agg(functions.array((0 until 100).map(i => functions.sum(dataframe("col").getItem(i))) : _*))

This generates the code to explicitly do 100 aggregations and then present the 100 results as an array of 100 items. However, this seems very inefficient, and Catalyst even fails to generate code for it once my array size goes above roughly 1000 items. Is there a construct in spark-sql to do this more efficiently? Ideally the sum aggregation would propagate automatically over arrays to do member-wise sums, but I found nothing related to this in the documentation. What are the alternatives to my code?

Edit: my traceback:

   ERROR codegen.CodeGenerator: failed to compile: org.codehaus.janino.InternalCompilerException: Compiling "GeneratedClass": Code of method "processNext()V" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator" grows beyond 64 KB
org.codehaus.janino.InternalCompilerException: Compiling "GeneratedClass": Code of method "processNext()V" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator" grows beyond 64 KB
    at org.codehaus.janino.UnitCompiler.compileUnit(UnitCompiler.java:361)
    at org.codehaus.janino.SimpleCompiler.cook(SimpleCompiler.java:234)
    at org.codehaus.janino.SimpleCompiler.compileToClassLoader(SimpleCompiler.java:446)
    at org.codehaus.janino.ClassBodyEvaluator.compileToClass(ClassBodyEvaluator.java:313)
    at org.codehaus.janino.ClassBodyEvaluator.cook(ClassBodyEvaluator.java:235)
    at org.codehaus.janino.SimpleCompiler.cook(SimpleCompiler.java:204)
    at org.codehaus.commons.compiler.Cookable.cook(Cookable.java:80)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.org$apache$spark$sql$catalyst$expressions$codegen$CodeGenerator$$doCompile(CodeGenerator.scala:1002)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$$anon$1.load(CodeGenerator.scala:1069)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$$anon$1.load(CodeGenerator.scala:1066)
    at org.spark_project.guava.cache.LocalCache$LoadingValueReference.loadFuture(LocalCache.java:3599)
    at org.spark_project.guava.cache.LocalCache$Segment.loadSync(LocalCache.java:2379)
    at org.spark_project.guava.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2342)
    at org.spark_project.guava.cache.LocalCache$Segment.get(LocalCache.java:2257)
    at org.spark_project.guava.cache.LocalCache.get(LocalCache.java:4000)
    at org.spark_project.guava.cache.LocalCache.getOrLoad(LocalCache.java:4004)
    at org.spark_project.guava.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4874)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.compile(CodeGenerator.scala:948)
    at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:375)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:97)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:92)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doExecute(HashAggregateExec.scala:92)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:97)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:92)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doExecute(HashAggregateExec.scala:92)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange.prepareShuffleDependency(ShuffleExchange.scala:88)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange$$anonfun$doExecute$1.apply(ShuffleExchange.scala:124)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange$$anonfun$doExecute$1.apply(ShuffleExchange.scala:115)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange.doExecute(ShuffleExchange.scala:115)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply$mcV$sp(FileFormatWriter.scala:173)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:166)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:166)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:65)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:166)
    at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:145)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:58)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:56)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:74)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.datasources.DataSource.writeInFileFormat(DataSource.scala:435)
    at org.apache.spark.sql.execution.datasources.DataSource.write(DataSource.scala:471)
    at org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand.run(SaveIntoDataSourceCommand.scala:48)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:58)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:56)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:74)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
    at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:609)
    at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:233)
    at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:217)
    at org.apache.spark.sql.DataFrameWriter.csv(DataFrameWriter.scala:597)
    at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer$.runAndSaveMetrics(RankingMetricsComputer.scala:286)
    at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer$.main(RankingMetricsComputer.scala:366)
    at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer.main(RankingMetricsComputer.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at org.apache.spark.deploy.yarn.ApplicationMaster$$anon$2.run(ApplicationMaster.scala:635)

Comments on the question:

Sorry, I don't have the Catalyst exception at hand right now, but it was about the generated code being too big / using too many variables.

If you get a chance, please edit your question with the traceback. That would help diagnose the problem and find a possible solution. Also, it would be great if you could include type annotations (what is the type of dataframe: Dataset[_]? RelationalGroupedDataset?). Anyway, performance-wise, you won't find a better solution than an aggregation.

Awesome. And what is the type of dataframe?

Related: SparkSQL job fails when calling stddev over 1,000 columns.

Answer 1:

The best way to do this is to convert the nested arrays into their own rows, so you can use a single groupBy. That way you can do it all in one aggregation instead of 100 (or more). The key is posexplode, which turns each entry of the array into a new row together with its index into the array.

For example:

// Assumes a SparkSession named `spark` (e.g. in spark-shell) for toDF and the $ column syntax.
import spark.implicits._
import org.apache.spark.sql.functions.{posexplode, collect_list, sum, struct}

val data = Seq(
    (Seq(1, 2, 3, 4, 5)),
    (Seq(2, 3, 4, 5, 6)),
    (Seq(3, 4, 5, 6, 7))
)

val df = data.toDF

val df2 = df.
    select(posexplode($"value")).
    groupBy($"pos").
    agg(sum($"col") as "sum")

// At this point you will have rows with the index and the sum
df2.orderBy($"pos".asc).show

This will output a DataFrame like this:

+---+---+
|pos|sum|
+---+---+
|  0|  6|
|  1|  9|
|  2| 12|
|  3| 15|
|  4| 18|
+---+---+

Or, if you want them back in a single row, you can do:

df2.groupBy().agg(collect_list(struct($"pos", $"sum")) as "list").show

This will not have the values in the array column sorted, but you could write a UDF to sort it by the pos field, and then drop the pos field if you want to.
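
A minimal sketch of such a sorting UDF (not part of the original answer; it assumes Spark 2.x, spark.implicits._ in scope, and that the sum column came out as LongType):

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{udf, collect_list, struct}

// Sort the collected (pos, sum) structs by pos and keep only the sums.
val sortByPos = udf { entries: Seq[Row] =>
  entries.sortBy(_.getInt(0)).map(_.getLong(1))
}

df2.groupBy().
    agg(collect_list(struct($"pos", $"sum")) as "list").
    select(sortByPos($"list") as "sums").
    show(false)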

Update based on comments

If the above approach doesn't work in combination with whatever other aggregations you are trying to do, then you will need to define your own UDAF. The general idea is that you tell Spark how to combine values for the same key within a partition to create intermediate values, and then how to combine those intermediate values across partitions to create the final value for each key. Once you have defined a UDAF class, you can use it in your agg call along with any other aggregations you want to do.

Here is a quick example I knocked out. Note that it assumes the array length and should probably be made more error-proof, but it should get you most of the way there.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction


class ArrayCombine extends UserDefinedAggregateFunction {
  // The input this aggregation will receive (each row)
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  // Your intermediate state as you are updating with data from each row
  override def bufferSchema: StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  // This is the output type of your aggregation function.
  override def dataType: DataType = ArrayType(IntegerType)

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema: an array of zeros
  // (the example assumes arrays of length 100).
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Array.fill(100)(0)
  }

  // Given a new input row, update our state
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val sums = buffer.getSeq[Int](0)
    val newVals = input.getSeq[Int](0)

    buffer(0) = sums.zip(newVals).map { case (a, b) => a + b }
  }

  // After we have finished computing intermediate values for each partition, combine the partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val sums1 = buffer1.getSeq[Int](0)
    val sums2 = buffer2.getSeq[Int](0)

    buffer1(0) = sums1.zip(sums2).map { case (a, b) => a + b }
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[Int](0)
  }
}

Then call it like this:

val arrayUdaf = new ArrayCombine()
df.groupBy().agg(arrayUdaf($"value")).show
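
As a side note (not from the original answer), Spark 2.x also lets you register a UDAF under a name and call it from SQL; a minimal sketch, where the table name t is just for illustration:

spark.udf.register("array_combine", new ArrayCombine)
df.createOrReplaceTempView("t")
spark.sql("SELECT array_combine(value) FROM t").show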

Comments on the answer:

Suggesting a UserDefinedAggregateFunction like this is really bad advice if @lezebulon is looking for better performance - see Spark UDAF with ArrayType as bufferSchema performance issues, especially with large arrays.

The linked approach is faster, but if the operation has to run alongside several other aggregations at the same time, I would try the UDAF first and see what it does on his data before reaching for more specialized and less flexible approaches.

Thanks a lot, I can't test it today but I will report my results tomorrow.

@user6910411 couldn't I apply the answer from the question you linked and use primitive types in place of ArrayType for my UDAF? If so, should that be faster while keeping the UDAF?

@lezebulon you could try, but if you are hitting some Catalyst issue it may not change anything, and even if it does, it would be nothing more than a workaround. Out of curiosity - does your code work when you skip the array?
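
For reference, a minimal sketch (assuming a column named value of type array<int> and spark.implicits._ in scope) of the kind of primitive, RDD-level reduction the linked question points to, summing the arrays element-wise with plain arrays:

// Widen to Long to avoid overflow when summing many Int arrays.
val summed: Array[Long] = df.
    select($"value").
    as[Seq[Int]].
    rdd.
    map(_.map(_.toLong).toArray).
    reduce { (a, b) => a.indices.map(i => a(i) + b(i)).toArray }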
