使用复杂类型查询 Spark SQL DataFrame

Posted

技术标签:

【中文标题】使用复杂类型查询 Spark SQL DataFrame【英文标题】:Querying Spark SQL DataFrame with complex types 【发布时间】:2015-02-04 22:12:38 【问题描述】:

如何查询具有复杂类型(例如地图/数组)的 RDD? 例如,当我在编写这个测试代码时:

case class Test(name: String, map: Map[String, String])
val map = Map("hello" -> "world", "hey" -> "there")
val map2 = Map("hello" -> "people", "hey" -> "you")
val rdd = sc.parallelize(Array(Test("first", map), Test("second", map2)))

我认为语法应该是这样的:

sqlContext.sql("SELECT * FROM rdd WHERE map.hello = world")

sqlContext.sql("SELECT * FROM rdd WHERE map[hello] = world")

但我明白了

无法访问 MapType(StringType,StringType,true) 类型的嵌套字段

org.apache.spark.sql.catalyst.errors.package$TreeNodeException:未解决的属性

分别。

【问题讨论】:

我赞成接受的答案,它是查询复杂字段的所有方法的绝佳来源。作为那些这样做的人的快速参考:map[hello] 不起作用的原因是键是一个字符串字段,所以你必须引用它:map['hello'] 【参考方案1】:

这取决于列的类型。让我们从一些虚拟数据开始:

import org.apache.spark.sql.functions.udf, lit
import scala.util.Try

case class SubRecord(x: Int)
case class ArrayElement(foo: String, bar: Int, vals: Array[Double])
case class Record(
  an_array: Array[Int], a_map: Map[String, String], 
  a_struct: SubRecord, an_array_of_structs: Array[ArrayElement])


val df = sc.parallelize(Seq(
  Record(Array(1, 2, 3), Map("foo" -> "bar"), SubRecord(1),
         Array(
           ArrayElement("foo", 1, Array(1.0, 2.0, 2.0)),
           ArrayElement("bar", 2, Array(3.0, 4.0, 5.0)))),
  Record(Array(4, 5, 6), Map("foz" -> "baz"), SubRecord(2),
         Array(ArrayElement("foz", 3, Array(5.0, 6.0)), 
               ArrayElement("baz", 4, Array(7.0, 8.0))))
)).toDF
df.registerTempTable("df")
df.printSchema

// root
// |-- an_array: array (nullable = true)
// |    |-- element: integer (containsNull = false)
// |-- a_map: map (nullable = true)
// |    |-- key: string
// |    |-- value: string (valueContainsNull = true)
// |-- a_struct: struct (nullable = true)
// |    |-- x: integer (nullable = false)
// |-- an_array_of_structs: array (nullable = true)
// |    |-- element: struct (containsNull = true)
// |    |    |-- foo: string (nullable = true)
// |    |    |-- bar: integer (nullable = false)
// |    |    |-- vals: array (nullable = true)
// |    |    |    |-- element: double (containsNull = false)

数组 (ArrayType) 列:

Column.getItem方法

df.select($"an_array".getItem(1)).show

// +-----------+
// |an_array[1]|
// +-----------+
// |          2|
// |          5|
// +-----------+

Hive 括号语法:

sqlContext.sql("SELECT an_array[1] FROM df").show

// +---+
// |_c0|
// +---+
// |  2|
// |  5|
// +---+

一个UDF

val get_ith = udf((xs: Seq[Int], i: Int) => Try(xs(i)).toOption)

df.select(get_ith($"an_array", lit(1))).show

// +---------------+
// |UDF(an_array,1)|
// +---------------+
// |              2|
// |              5|
// +---------------+

除了上面列出的方法之外,Spark 还支持越来越多的对复杂类型进行操作的内置函数。值得注意的例子包括高阶函数,如 transform(SQL 2.4+、Scala 3.0+、PySpark / SparkR 3.1+):

df.selectExpr("transform(an_array, x -> x + 1) an_array_inc").show
// +------------+
// |an_array_inc|
// +------------+
// |   [2, 3, 4]|
// |   [5, 6, 7]|
// +------------+

import org.apache.spark.sql.functions.transform

df.select(transform($"an_array", x => x + 1) as "an_array_inc").show
// +------------+
// |an_array_inc|
// +------------+
// |   [2, 3, 4]|
// |   [5, 6, 7]|
// +------------+

filter(SQL 2.4+、Scala 3.0+、Python / SparkR 3.1+)

df.selectExpr("filter(an_array, x -> x % 2 == 0) an_array_even").show
// +-------------+
// |an_array_even|
// +-------------+
// |          [2]|
// |       [4, 6]|
// +-------------+

import org.apache.spark.sql.functions.filter

df.select(filter($"an_array", x => x % 2 === 0) as "an_array_even").show
// +-------------+
// |an_array_even|
// +-------------+
// |          [2]|
// |       [4, 6]|
// +-------------+

aggregate(SQL 2.4+、Scala 3.0+、PySpark / SparkR 3.1+):

df.selectExpr("aggregate(an_array, 0, (acc, x) -> acc + x, acc -> acc) an_array_sum").show
// +------------+
// |an_array_sum|
// +------------+
// |           6|
// |          15|
// +------------+

import org.apache.spark.sql.functions.aggregate

df.select(aggregate($"an_array", lit(0), (x, y) => x + y) as "an_array_sum").show
// +------------+                                                                  
// |an_array_sum|
// +------------+
// |           6|
// |          15|
// +------------+

数组处理函数(array_*),如array_distinct(2.4+):

import org.apache.spark.sql.functions.array_distinct

df.select(array_distinct($"an_array_of_structs.vals"(0))).show
// +-------------------------------------------+
// |array_distinct(an_array_of_structs.vals[0])|
// +-------------------------------------------+
// |                                 [1.0, 2.0]|
// |                                 [5.0, 6.0]|
// +-------------------------------------------+

array_max (array_min, 2.4+):

import org.apache.spark.sql.functions.array_max

df.select(array_max($"an_array")).show
// +-------------------+
// |array_max(an_array)|
// +-------------------+
// |                  3|
// |                  6|
// +-------------------+

flatten (2.4+)

import org.apache.spark.sql.functions.flatten

df.select(flatten($"an_array_of_structs.vals")).show
// +---------------------------------+
// |flatten(an_array_of_structs.vals)|
// +---------------------------------+
// |             [1.0, 2.0, 2.0, 3...|
// |             [5.0, 6.0, 7.0, 8.0]|
// +---------------------------------+

arrays_zip (2.4+):

import org.apache.spark.sql.functions.arrays_zip

df.select(arrays_zip($"an_array_of_structs.vals"(0), $"an_array_of_structs.vals"(1))).show(false)
// +--------------------------------------------------------------------+
// |arrays_zip(an_array_of_structs.vals[0], an_array_of_structs.vals[1])|
// +--------------------------------------------------------------------+
// |[[1.0, 3.0], [2.0, 4.0], [2.0, 5.0]]                                |
// |[[5.0, 7.0], [6.0, 8.0]]                                            |
// +--------------------------------------------------------------------+

array_union (2.4+):

import org.apache.spark.sql.functions.array_union

df.select(array_union($"an_array_of_structs.vals"(0), $"an_array_of_structs.vals"(1))).show
// +---------------------------------------------------------------------+
// |array_union(an_array_of_structs.vals[0], an_array_of_structs.vals[1])|
// +---------------------------------------------------------------------+
// |                                                 [1.0, 2.0, 3.0, 4...|
// |                                                 [5.0, 6.0, 7.0, 8.0]|
// +---------------------------------------------------------------------+

slice (2.4+):

import org.apache.spark.sql.functions.slice

df.select(slice($"an_array", 2, 2)).show
// +---------------------+
// |slice(an_array, 2, 2)|
// +---------------------+
// |               [2, 3]|
// |               [5, 6]|
// +---------------------+

映射 (MapType) 列

使用Column.getField方法:

df.select($"a_map".getField("foo")).show

// +----------+
// |a_map[foo]|
// +----------+
// |       bar|
// |      null|
// +----------+

使用 Hive 括号语法:

sqlContext.sql("SELECT a_map['foz'] FROM df").show

// +----+
// | _c0|
// +----+
// |null|
// | baz|
// +----+

使用带点语法的完整路径:

df.select($"a_map.foo").show

// +----+
// | foo|
// +----+
// | bar|
// |null|
// +----+

使用 UDF

val get_field = udf((kvs: Map[String, String], k: String) => kvs.get(k))

df.select(get_field($"a_map", lit("foo"))).show

// +--------------+
// |UDF(a_map,foo)|
// +--------------+
// |           bar|
// |          null|
// +--------------+

越来越多的map_* 函数,如map_keys (2.3+)

import org.apache.spark.sql.functions.map_keys

df.select(map_keys($"a_map")).show
// +---------------+
// |map_keys(a_map)|
// +---------------+
// |          [foo]|
// |          [foz]|
// +---------------+

map_values (2.3+)

import org.apache.spark.sql.functions.map_values

df.select(map_values($"a_map")).show
// +-----------------+
// |map_values(a_map)|
// +-----------------+
// |            [bar]|
// |            [baz]|
// +-----------------+

详情请查看SPARK-23899。

struct (StructType) 列使用带有点语法的完整路径:

使用 DataFrame API

df.select($"a_struct.x").show

// +---+
// |  x|
// +---+
// |  1|
// |  2|
// +---+

使用原始 SQL

sqlContext.sql("SELECT a_struct.x FROM df").show

// +---+
// |  x|
// +---+
// |  1|
// |  2|
// +---+

structs 数组中的字段可以使用点语法、名称和标准 Column 方法访问:

df.select($"an_array_of_structs.foo").show

// +----------+
// |       foo|
// +----------+
// |[foo, bar]|
// |[foz, baz]|
// +----------+

sqlContext.sql("SELECT an_array_of_structs[0].foo FROM df").show

// +---+
// |_c0|
// +---+
// |foo|
// |foz|
// +---+

df.select($"an_array_of_structs.vals".getItem(1).getItem(1)).show

// +------------------------------+
// |an_array_of_structs.vals[1][1]|
// +------------------------------+
// |                           4.0|
// |                           8.0|
// +------------------------------+

可以使用 UDF 访问用户定义类型 (UDT) 字段。详情请见Spark SQL referencing attributes of UDT。

注意事项

根据 Spark 版本,其中一些方法仅适用于 HiveContext。 UDF 应该独立于使用标准 SQLContextHiveContext 的版本。

一般来说,嵌套值是二等公民。嵌套字段并不支持所有典型操作。根据上下文,展平架构和/或分解集合可能会更好

df.select(explode($"an_array_of_structs")).show

// +--------------------+
// |                 col|
// +--------------------+
// |[foo,1,WrappedArr...|
// |[bar,2,WrappedArr...|
// |[foz,3,WrappedArr...|
// |[baz,4,WrappedArr...|
// +--------------------+

点语法可以与通配符 (*) 组合来选择(可能是多个)字段,而无需明确指定名称:

df.select($"a_struct.*").show
// +---+
// |  x|
// +---+
// |  1|
// |  2|
// +---+

可以使用get_json_objectfrom_json 函数查询JSON 列。详情请见How to query JSON data column using Spark DataFrames?。

【讨论】:

是否可以获取结构数组中的所有元素?这样的事情可能吗.. sqlContext.sql("SELECT an_array_of_structs[0].foo FROM df").show 如何使用代码而不是 spark sql 做与SELECT an_array_of_structs[0].foo FROM df 相同的事情?是否支持使用代码在结构列数组(an_array_of_structs)上执行 UDF?喜欢SELECT max(an_array_of_structs.bar) FROM df 使用代码。 哇。很好的开放答案。非常感谢。 哇^10 惊人的答案! 尝试导入 org.apache.spark.sql.functions.transform 时出现错误。所有其他导入似乎都有效,知道为什么会发生这种情况吗?【参考方案2】:

将其转换为 DF 后,您可以简单地获取数据为

  val rddRow= rdd.map(kv=>
    val k = kv._1
    val v = kv._2
    Row(k, v)
  )

val myFld1 =  StructField("name", org.apache.spark.sql.types.StringType, true)
val myFld2 =  StructField("map", org.apache.spark.sql.types.MapType(StringType, StringType), true)
val arr = Array( myFld1, myFld2)
val schema = StructType( arr )
val rowrddDF = sqc.createDataFrame(rddRow, schema)
rowrddDF.registerTempTable("rowtbl")  
val rowrddDFFinal = rowrddDF.select(rowrddDF("map.one"))
or
val rowrddDFFinal = rowrddDF.select("map.one")

【讨论】:

当我尝试这个时,我得到error: value _1 is not a member of org.apache.spark.sql.Row【参考方案3】:

这就是我所做的,它奏效了

case class Test(name: String, m: Map[String, String])
val map = Map("hello" -> "world", "hey" -> "there")
val map2 = Map("hello" -> "people", "hey" -> "you")
val rdd = sc.parallelize(Array(Test("first", map), Test("second", map2)))
val rdddf = rdd.toDF
rdddf.registerTempTable("mytable")
sqlContext.sql("select m.hello from mytable").show

结果

+------+
| hello|
+------+
| world|
|people|
+------+

【讨论】:

以上是关于使用复杂类型查询 Spark SQL DataFrame的主要内容,如果未能解决你的问题,请参考以下文章

Spark SQL 查询中的高阶函数

在镶木地板的地图类型列上使用 spark-sql 过滤下推

无法使用 pyspark 从 hive 表中查询复杂的 SQL 语句

Spark SQL分区感知查询hive表

一条 SQL 在 Apache Spark 之旅(上)

在 Spark SQL 中将 long 类型的列转换为 calendarinterval 类型