How to do member-wise operations on an array in Spark SQL?
【Title】How to do member-wise operations on array in spark sql? 【Posted】2018-10-16 19:39:12
【Question】In Spark SQL I have a DataFrame with a column col that holds an Int array of size 100 (for example).
I want to aggregate this column into a single value: an Int array of size 100 where each entry is the sum of the corresponding entries in that column. This can be done by calling:
dataframe.agg(functions.array((0 until 100).map(i => functions.sum(functions.col("col")(i))) : _*))
This generates code that explicitly performs the 100 aggregations and then presents the 100 results as an array of 100 items. However, this seems very inefficient, and Catalyst cannot even generate the code once my arrays exceed roughly 1000 items.
Is there a construct in Spark SQL to do this more efficiently?
Ideally a sum aggregation should propagate automatically over an array to sum it member-wise, but I have found nothing about this in the documentation.
What are the alternatives to my code?
Edit: my traceback:
ERROR codegen.CodeGenerator: failed to compile: org.codehaus.janino.InternalCompilerException: Compiling "GeneratedClass": Code of method "processNext()V" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator" grows beyond 64 KB
org.codehaus.janino.InternalCompilerException: Compiling "GeneratedClass": Code of method "processNext()V" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator" grows beyond 64 KB
at org.codehaus.janino.UnitCompiler.compileUnit(UnitCompiler.java:361)
at org.codehaus.janino.SimpleCompiler.cook(SimpleCompiler.java:234)
at org.codehaus.janino.SimpleCompiler.compileToClassLoader(SimpleCompiler.java:446)
at org.codehaus.janino.ClassBodyEvaluator.compileToClass(ClassBodyEvaluator.java:313)
at org.codehaus.janino.ClassBodyEvaluator.cook(ClassBodyEvaluator.java:235)
at org.codehaus.janino.SimpleCompiler.cook(SimpleCompiler.java:204)
at org.codehaus.commons.compiler.Cookable.cook(Cookable.java:80)
at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.org$apache$spark$sql$catalyst$expressions$codegen$CodeGenerator$$doCompile(CodeGenerator.scala:1002)
at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$$anon$1.load(CodeGenerator.scala:1069)
at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$$anon$1.load(CodeGenerator.scala:1066)
at org.spark_project.guava.cache.LocalCache$LoadingValueReference.loadFuture(LocalCache.java:3599)
at org.spark_project.guava.cache.LocalCache$Segment.loadSync(LocalCache.java:2379)
at org.spark_project.guava.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2342)
at org.spark_project.guava.cache.LocalCache$Segment.get(LocalCache.java:2257)
at org.spark_project.guava.cache.LocalCache.get(LocalCache.java:4000)
at org.spark_project.guava.cache.LocalCache.getOrLoad(LocalCache.java:4004)
at org.spark_project.guava.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4874)
at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.compile(CodeGenerator.scala:948)
at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:375)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:97)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:92)
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doExecute(HashAggregateExec.scala:92)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:97)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:92)
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doExecute(HashAggregateExec.scala:92)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.exchange.ShuffleExchange.prepareShuffleDependency(ShuffleExchange.scala:88)
at org.apache.spark.sql.execution.exchange.ShuffleExchange$$anonfun$doExecute$1.apply(ShuffleExchange.scala:124)
at org.apache.spark.sql.execution.exchange.ShuffleExchange$$anonfun$doExecute$1.apply(ShuffleExchange.scala:115)
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
at org.apache.spark.sql.execution.exchange.ShuffleExchange.doExecute(ShuffleExchange.scala:115)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply$mcV$sp(FileFormatWriter.scala:173)
at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:166)
at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:166)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:65)
at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:166)
at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:145)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:58)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:56)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:74)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
at org.apache.spark.sql.execution.datasources.DataSource.writeInFileFormat(DataSource.scala:435)
at org.apache.spark.sql.execution.datasources.DataSource.write(DataSource.scala:471)
at org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand.run(SaveIntoDataSourceCommand.scala:48)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:58)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:56)
at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:74)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:609)
at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:233)
at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:217)
at org.apache.spark.sql.DataFrameWriter.csv(DataFrameWriter.scala:597)
at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer$.runAndSaveMetrics(RankingMetricsComputer.scala:286)
at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer$.main(RankingMetricsComputer.scala:366)
at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer.main(RankingMetricsComputer.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at org.apache.spark.deploy.yarn.ApplicationMaster$$anon$2.run(ApplicationMaster.scala:635)
【Comments】:
Sorry, I don't have the Catalyst exception at hand right now, but it was about the generated code being too large / using too many variables.
If you get a chance, please edit your question and attach the traceback; that will help diagnose the problem and find possible solutions. Also, it would be great if you could include type annotations (what is dataframe: a Dataset[_], a RelationalGroupedDataset?). In any case, performance-wise you won't find a much better solution than an aggregation.
Great. And what is the type of dataframe?
Related: SparkSQL job fails when calling stddev over 1,000 columns.
【Answer 1】:
The best way to do this is to turn the nested arrays into their own rows so you can use a single groupBy. That way you can do everything in one aggregation instead of 100 (or more). The key to this is posexplode, which turns each entry of the array into a new row together with its index within the array.
For example:
import org.apache.spark.sql.functions.{posexplode, collect_list, struct, sum}
val data = Seq(
(Seq(1, 2, 3, 4, 5)),
(Seq(2, 3, 4, 5, 6)),
(Seq(3, 4, 5, 6, 7))
)
val df = data.toDF
val df2 = df.
select(posexplode($"value")).
groupBy($"pos").
agg(sum($"col") as "sum")
// At this point you will have rows with the index and the sum
df2.orderBy($"pos".asc).show
This will output a DataFrame like this:
+---+---+
|pos|sum|
+---+---+
| 0| 6|
| 1| 9|
| 2| 12|
| 3| 15|
| 4| 18|
+---+---+
Or if you want them back in one row, you can do:
df2.groupBy().agg(collect_list(struct($"pos", $"sum")) as "list").show
This won't have the values in the Array column sorted, but you could write a UDF to sort them by the pos field and drop the pos field afterwards if you want to.
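For what it's worth, here is a minimal sketch of such a UDF (assuming Spark 2.x, where sum over an Int column produces a bigint, so each collected struct is (pos: int, sum: bigint); sortByPos is just an illustrative name):

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{collect_list, struct, udf}

// Sort the collected (pos, sum) structs by pos and keep only the sums.
val sortByPos = udf { (entries: Seq[Row]) =>
  entries.sortBy(_.getInt(0)).map(_.getLong(1))
}

df2.groupBy()
  .agg(collect_list(struct($"pos", $"sum")) as "list")
  .select(sortByPos($"list") as "sums")
  .show(false)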
Update based on comments
If the above approach doesn't work for whatever other aggregations you are trying to do, you will need to define your own UDAF. The general idea is that you tell Spark how to combine values for the same key within a partition to create intermediate values, and then how to combine those intermediate values across partitions to create the final value for each key. Once you have defined a UDAF class, you can use it in your agg call alongside any other aggregations you want to do.
Here is a quick example I knocked out. Note that it assumes the array length and should probably be made more error-proof, but it should get you most of the way there.
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
class ArrayCombine extends UserDefinedAggregateFunction {
  // The input this aggregation will receive (each row): one array<int> column
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  // Your intermediate state as you are updating with data from each row
  override def bufferSchema: StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  // This is the output type of your aggregation function.
  override def dataType: DataType = ArrayType(IntegerType)

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema: an all-zero array
  // (100 is the hard-coded array length).
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Array.fill(100)(0)
  }

  // Given a new input row, update our state
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val sums = buffer.getSeq[Int](0)
    val newVals = input.getSeq[Int](0)
    buffer(0) = sums.zip(newVals).map { case (a, b) => a + b }
  }

  // After we have finished computing intermediate values for each partition, combine the partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val sums1 = buffer1.getSeq[Int](0)
    val sums2 = buffer2.getSeq[Int](0)
    buffer1(0) = sums1.zip(sums2).map { case (a, b) => a + b }
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any =
    buffer.getSeq[Int](0)
}
Then call it like this:
val arrayUdaf = new ArrayCombine()
df.groupBy().agg(arrayUdaf($"value")).show
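If you also want to call it from SQL, registering the UDAF should work along these lines (a sketch; array_combine and my_table are arbitrary names):

spark.udf.register("array_combine", arrayUdaf)
df.createOrReplaceTempView("my_table")
spark.sql("SELECT array_combine(value) AS sums FROM my_table").show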
【Comments】:
Recommending a UserDefinedAggregateFunction like this one is really bad advice if @lezebulon is looking for better performance, especially with large arrays: see Spark UDAF with ArrayType as bufferSchema performance issues.
The linked approach is faster, but if the operation needs to run several aggregations at the same time I would try the UDAF first and see how it does on his data before moving to the more specialized and less flexible approach.
Thanks a lot, I can't test this today but I will report my results tomorrow.
@user6910411 Couldn't I apply the answer from the question you linked and use primitive types in place of ArrayType for my UDAF? If so, shouldn't it be faster while still keeping a UDAF?
@lezebulon You can try, but if you are hitting some Catalyst issue it may not change anything, and even if it does it would only be a workaround. Out of curiosity: does your code work when you skip the array?
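For reference, "primitive types in place of ArrayType" for this UDAF could look roughly like the sketch below (a hypothetical variant, not from the original thread: one LongType buffer field per array slot instead of a single ArrayType field, with the length fixed at construction):

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}

// Element-wise sum of an array<int> column. The buffer holds one primitive
// LongType field per slot, so each update mutates fixed-width fields instead
// of rebuilding an entire array per input row.
class PrimitiveArraySum(n: Int) extends UserDefinedAggregateFunction {
  override def inputSchema: StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  override def bufferSchema: StructType =
    StructType((0 until n).map(i => StructField(s"sum_$i", LongType)))

  override def dataType: DataType = ArrayType(LongType)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    var i = 0
    while (i < n) { buffer(i) = 0L; i += 1 }
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val xs = input.getSeq[Int](0)
    var i = 0
    while (i < n) { buffer(i) = buffer.getLong(i) + xs(i); i += 1 }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    var i = 0
    while (i < n) { buffer1(i) = buffer1.getLong(i) + buffer2.getLong(i); i += 1 }
  }

  override def evaluate(buffer: Row): Any = (0 until n).map(buffer.getLong)
}

// usage: df.groupBy().agg(new PrimitiveArraySum(100)($"value") as "sums").show

Whether this actually helps depends on whether the bottleneck is the ArrayType buffer serialization rather than Catalyst code generation, as discussed above.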