Spark SQL add column/update-accumulate value

Posted: 2018-11-09 14:05:22

Question:

I have the following DataFrame:

name,email,phone,country
------------------------------------------------
[Mike,mike@example.com,+91-9999999999,Italy]
[Alex,alex@example.com,+91-9999999998,France]
[John,john@example.com,+1-1111111111,United States]
[Donald,donald@example.com,+1-2222222222,United States]
[Dan,dan@example.com,+91-9999444999,Poland]
[Scott,scott@example.com,+91-9111999998,Spain]
[Rob,rob@example.com,+91-9114444998,Italy]

It is exposed as the temporary view tagged_users:

resultDf.createOrReplaceTempView("tagged_users")

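I need to add an extra column, tag, to this DataFrame and assign computed tags according to different SQL conditions, described in the map below (key: tag name, value: condition for the WHERE clause):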

val tags = Map(
  "big" -> "country IN (SELECT * FROM big_countries)",
  "medium" -> "country IN (SELECT * FROM medium_countries)",
  //2000 other different tags and conditions
  "sometag" -> "name = 'Donald' AND email = 'donald@example.com' AND phone = '+1-2222222222'"
  )
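The order in which matched tags are joined (for example `big|sometag`) follows the iteration order of `tags`. Scala's default immutable `Map` only keeps insertion order for up to four entries; with the roughly 2000 tags mentioned above it becomes a hash map with unspecified order. A minimal sketch, assuming an ordered `ListMap` is acceptable as the container (the `small` and `tiny` entries and their tables are hypothetical, added only to exceed four entries):

```scala
import scala.collection.immutable.ListMap

object TagOrder {
  // ListMap keeps insertion order at any size; the default immutable Map
  // only does so for up to four entries (beyond that it is a HashMap)
  val tags: ListMap[String, String] = ListMap(
    "big"     -> "country IN (SELECT * FROM big_countries)",
    "medium"  -> "country IN (SELECT * FROM medium_countries)",
    "small"   -> "country IN (SELECT * FROM small_countries)", // hypothetical table
    "tiny"    -> "country IN (SELECT * FROM tiny_countries)",  // hypothetical table
    "sometag" -> "name = 'Donald' AND email = 'donald@example.com' AND phone = '+1-2222222222'"
  )

  def main(args: Array[String]): Unit =
    println(tags.keys.mkString("|"))
}
```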

To be able to use them in SQL queries, I have the following DataFrames (acting as a data dictionary):

Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")

I want to test each row of my tagged_users table and assign the appropriate tags to it. To do that, I tried to implement the following logic:

tags.foreach { case (tag, tagCondition) =>
  // resultDf is reassigned on every iteration, so only the last evaluated tag's rows are kept
  resultDf = spark.sql(buildTagQuery(tag, tagCondition, "tagged_users"))
    .withColumn("tag", lit(tag).cast(StringType))
}

def buildTagQuery(tag: String, tagCondition: String, table: String): String =
  s"SELECT * FROM $table WHERE $tagCondition"

But now I don't know how to accumulate the tags instead of overwriting them. Currently I get the following DataFrame:

name,email,phone,country,tag
Dan,dan@example.com,+91-9999444999,Poland,medium
Scott,scott@example.com,+91-9111999998,Spain,medium

But I need something like:

name,email,phone,country,tag
Mike,mike@example.com,+91-9999999999,Italy,big
Alex,alex@example.com,+91-9999999998,France,big
John,john@example.com,+1-1111111111,United States,big
Donald,donald@example.com,+1-2222222222,United States,(big|sometag)
Dan,dan@example.com,+91-9999444999,Poland,medium
Scott,scott@example.com,+91-9111999998,Spain,(big|medium)
Rob,rob@example.com,+91-9114444998,Italy,big

Note that Donald should have 2 tags (big|sometag) and Scott should have 2 tags (big|medium).
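The required semantics can be stated without Spark: every tag whose condition holds for a row is kept, and the kept tags are joined with `|`. The sketch below models this in plain Scala, under the assumption that each SQL condition is replaced by an equivalent Scala predicate and that tags are held in an ordered `Seq` (the `User` case class and `TagSemantics` object are hypothetical):

```scala
object TagSemantics {
  // Hypothetical in-memory row; the real data lives in the tagged_users view
  final case class User(name: String, email: String, phone: String, country: String)

  val bigCountries    = Set("Italy", "France", "United States", "Spain")
  val mediumCountries = Set("Poland", "Hungary", "Spain")

  // Ordered (tag, condition) pairs; predicates stand in for the SQL conditions
  val tags: Seq[(String, User => Boolean)] = Seq(
    "big"     -> ((u: User) => bigCountries.contains(u.country)),
    "medium"  -> ((u: User) => mediumCountries.contains(u.country)),
    "sometag" -> ((u: User) =>
      u.name == "Donald" && u.email == "donald@example.com" && u.phone == "+1-2222222222")
  )

  // Every matching tag is kept; nothing is overwritten
  def tagsFor(u: User): String =
    tags.collect { case (tag, cond) if cond(u) => tag }.mkString("|")

  val users = Seq(
    User("Donald", "donald@example.com", "+1-2222222222", "United States"),
    User("Scott", "scott@example.com", "+91-9111999998", "Spain"),
    User("Dan", "dan@example.com", "+91-9999444999", "Poland")
  )

  def main(args: Array[String]): Unit =
    users.foreach(u => println(s"${u.name}: ${tagsFor(u)}"))
}
```

Each condition is evaluated independently for each row, which is the accumulating behavior the `foreach` loop above lacks.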

Please explain how to implement this.

UPDATE

val spark = SparkSession
  .builder()
  .appName("Java Spark SQL basic example")
  .config("spark.master", "local")
  .getOrCreate()

import spark.implicits._
import spark.sql

Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")

val df = Seq(
  ("Mike", "mike@example.com", "+91-9999999999", "Italy"),
  ("Alex", "alex@example.com", "+91-9999999998", "France"),
  ("John", "john@example.com", "+1-1111111111", "United States"),
  ("Donald", "donald@example.com", "+1-2222222222", "United States"),
  ("Dan", "dan@example.com", "+91-9999444999", "Poland"),
  ("Scott", "scott@example.com", "+91-9111999998", "Spain"),
  ("Rob", "rob@example.com", "+91-9114444998", "Italy")).toDF("name", "email", "phone", "country")

df.collect.foreach(println)

df.createOrReplaceTempView("tagged_users")

val tags = Map(
  "big" -> "country IN (SELECT * FROM big_countries)",
  "medium" -> "country IN (SELECT * FROM medium_countries)",
  "sometag" -> "name = 'Donald' AND email = 'donald@example.com' AND phone = '+1-2222222222'")

val sep_tag = tags.map((x) =>  s"when array_contains(" + x._1 + ", country) then '" + x._1 + "' " ).mkString

val combine_sel_tag1 = tags.map((x) =>  s" array_contains(" + x._1 + ",country) " ).mkString(" and ")

val combine_sel_tag2 = tags.map((x) => x._1).mkString(" '(", "|", ")' ")

val combine_sel_all = " case when " + combine_sel_tag1 + " then " + combine_sel_tag2 + sep_tag + " end as tags "

val crosqry = tags.map((x) =>  s" cross join ( select collect_list(country) as " + x._1 + " from " + x._1 + "_countries) " + x._1 + "  " ).mkString

val qry = " select name,email,phone,country, " + combine_sel_all + " from tagged_users " + crosqry

spark.sql(qry).show

spark.stop()

It fails with the following exception:

Caused by: org.apache.spark.sql.catalyst.analysis.NoSuchTableException: Table or view 'sometag_countries' not found in database 'default';
    at org.apache.spark.sql.catalyst.catalog.ExternalCatalog$class.requireTableExists(ExternalCatalog.scala:48)
    at org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.requireTableExists(InMemoryCatalog.scala:45)
    at org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.getTable(InMemoryCatalog.scala:326)
    at org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener.getTable(ExternalCatalogWithListener.scala:138)
    at org.apache.spark.sql.catalyst.catalog.SessionCatalog.lookupRelation(SessionCatalog.scala:701)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations$.org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveRelations$$lookupTableFromCatalog(Analyzer.scala:730)
    ... 74 more
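The exception comes from the `crosqry` generator: it emits `cross join ( select collect_list(country) ... from <tag>_countries)` for every key of `tags`, but `sometag` is a plain column condition with no `sometag_countries` view behind it. One possible fix, sketched here under the assumption that lookup tables are only ever referenced as `<column> IN (SELECT * FROM <table>)`, is to derive the lookup tables from the conditions themselves and generate cross joins only for those (`LookupTables` is a hypothetical helper):

```scala
object LookupTables {
  // Assumption: lookup-table references always look like "<col> IN (SELECT * FROM <table>)"
  private val InSubquery = """(?i)(\w+)\s+IN\s*\(\s*SELECT\s+\*\s+FROM\s+(\w+)\s*\)""".r

  // (tag, column, table) only for tags whose condition references a lookup table
  def lookups(tags: Seq[(String, String)]): Seq[(String, String, String)] =
    tags.flatMap { case (tag, cond) =>
      InSubquery.findFirstMatchIn(cond).map(m => (tag, m.group(1), m.group(2)))
    }

  val tags = Seq(
    "big"     -> "country IN (SELECT * FROM big_countries)",
    "medium"  -> "country IN (SELECT * FROM medium_countries)",
    "sometag" -> "name = 'Donald' AND email = 'donald@example.com' AND phone = '+1-2222222222'"
  )

  def main(args: Array[String]): Unit =
    lookups(tags).foreach(println)
}
```

Tags such as `sometag` that fall outside the pattern still need their own `CASE WHEN <condition>` in the select list.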

Comments:

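Why not have a single countries table with two columns, the country name and a description? The description could be a single value from (small, medium, big, big|medium, small|medium). Then you would only need to join the two tables on the country name.

Because this is just one special case. By system design, users can configure any number of tags, with different conditions and names, through the UI. The same applies to collections such as big_countries, medium_countries, etc.: users can configure any number of collections with different names and elements through the UI and reference them in their SQL queries.

Answer 1: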

Check out this DataFrame solution:

scala> val df = Seq(("Mike","mike@example.com","+91-9999999999","Italy"),
     | ("Alex","alex@example.com","+91-9999999998","France"),
     | ("John","john@example.com","+1-1111111111","United States"),
     | ("Donald","donald@example.com","+1-2222222222","United States"),
     | ("Dan","dan@example.com","+91-9999444999","Poland"),
     | ("Scott","scott@example.com","+91-9111999998","Spain"),
     | ("Rob","rob@example.com","+91-9114444998","Italy")
     | ).toDF("name","email","phone","country")
df: org.apache.spark.sql.DataFrame = [name: string, email: string ... 2 more fields]

scala> val dfbc=Seq("Italy", "France", "United States", "Spain").toDF("country")
dfbc: org.apache.spark.sql.DataFrame = [country: string]

scala> val dfmc=Seq("Poland", "Hungary", "Spain").toDF("country")
dfmc: org.apache.spark.sql.DataFrame = [country: string]

scala> val dfbc2=dfbc.agg(collect_list('country).as("bcountry"))
dfbc2: org.apache.spark.sql.DataFrame = [bcountry: array<string>]

scala> val dfmc2=dfmc.agg(collect_list('country).as("mcountry"))
dfmc2: org.apache.spark.sql.DataFrame = [mcountry: array<string>]

scala> val df2=df.crossJoin(dfbc2).crossJoin(dfmc2)
df2: org.apache.spark.sql.DataFrame = [name: string, email: string ... 4 more fields]

scala> df2.selectExpr("*","case when array_contains(bcountry,country) and array_contains(mcountry,country) then '(big|medium)' when array_contains(bcountry,country) then 'big' when array_contains(mcountry,country) then 'medium' else 'none' end as `tags`").select("name","email","phone","country","tags").show(false)
+------+------------------+--------------+-------------+------------+
|name  |email             |phone         |country      |tags        |
+------+------------------+--------------+-------------+------------+
|Mike  |mike@example.com  |+91-9999999999|Italy        |big         |
|Alex  |alex@example.com  |+91-9999999998|France       |big         |
|John  |john@example.com  |+1-1111111111 |United States|big         |
|Donald|donald@example.com|+1-2222222222 |United States|big         |
|Dan   |dan@example.com   |+91-9999444999|Poland       |medium      |
|Scott |scott@example.com |+91-9111999998|Spain        |(big|medium)|
|Rob   |rob@example.com   |+91-9114444998|Italy        |big         |
+------+------------------+--------------+-------------+------------+


scala>
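The `CASE` chain above needs a separate branch for every combination of tags, which does not scale past a handful of tags. A minimal sketch of an alternative, assuming one `collect_list` array column per tag as in `dfbc2`/`dfmc2`: emit one `CASE WHEN ... THEN '<tag>' END` per tag with no `ELSE`, so a non-matching tag yields NULL, and let `concat_ws` (which skips NULLs) join the rest. `DfTagsExpr` and `tagsExpr` are hypothetical names:

```scala
object DfTagsExpr {
  // tagArrays: (tag name, name of the array<string> column built with collect_list)
  def tagsExpr(tagArrays: Seq[(String, String)], valueCol: String): String =
    tagArrays
      .map { case (tag, arrCol) => s"CASE WHEN array_contains($arrCol, $valueCol) THEN '$tag' END" }
      .mkString("concat_ws('|', ", ", ", ") AS tags")

  def main(args: Array[String]): Unit =
    println(tagsExpr(Seq("big" -> "bcountry", "medium" -> "mcountry"), "country"))
}
```

The resulting string could be passed to `df2.selectExpr("name", "email", "phone", "country", DfTagsExpr.tagsExpr(Seq("big" -> "bcountry", "medium" -> "mcountry"), "country"))`. Rows with no matching tag get an empty string rather than `'none'`.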

SQL approach

scala> Seq(("Mike","mike@example.com","+91-9999999999","Italy"),
     |       ("Alex","alex@example.com","+91-9999999998","France"),
     |       ("John","john@example.com","+1-1111111111","United States"),
     |       ("Donald","donald@example.com","+1-2222222222","United States"),
     |       ("Dan","dan@example.com","+91-9999444999","Poland"),
     |       ("Scott","scott@example.com","+91-9111999998","Spain"),
     |       ("Rob","rob@example.com","+91-9114444998","Italy")
     |       ).toDF("name","email","phone","country").createOrReplaceTempView("tagged_users")

scala> Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")

scala> Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")

scala> spark.sql(""" select name,email,phone,country,case when array_contains(bc,country) and array_contains(mc,country) then '(big|medium)' when array_contains(bc,country) then 'big' when array_contains(mc,country) then 'medium' else 'none' end as tags from tagged_users cross join ( select collect_list(country) as bc from big_countries ) b cross join ( select collect_list(country) as mc from medium_countries ) c """).show(false)
+------+------------------+--------------+-------------+------------+
|name  |email             |phone         |country      |tags        |
+------+------------------+--------------+-------------+------------+
|Mike  |mike@example.com  |+91-9999999999|Italy        |big         |
|Alex  |alex@example.com  |+91-9999999998|France       |big         |
|John  |john@example.com  |+1-1111111111 |United States|big         |
|Donald|donald@example.com|+1-2222222222 |United States|big         |
|Dan   |dan@example.com   |+91-9999444999|Poland       |medium      |
|Scott |scott@example.com |+91-9111999998|Spain        |(big|medium)|
|Rob   |rob@example.com   |+91-9114444998|Italy        |big         |
+------+------------------+--------------+-------------+------------+


scala>

Iterating over the tags

scala> val tags = Map(
     |   "big" -> "country IN (SELECT * FROM big_countries)",
     |   "medium" -> "country IN (SELECT * FROM medium_countries)")
tags: scala.collection.immutable.Map[String,String] = Map(big -> country IN (SELECT * FROM big_countries), medium -> country IN (SELECT * FROM medium_countries))

scala> val sep_tag = tags.map( (x) =>  s"when array_contains("+x._1+", country) then '" + x._1 + "' "  ).mkString
sep_tag: String = "when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium' "

scala> val combine_sel_tag1 = tags.map( (x) =>  s" array_contains("+x._1+",country) "  ).mkString(" and ")
combine_sel_tag1: String = " array_contains(big,country)  and  array_contains(medium,country) "

scala> val combine_sel_tag2 = tags.map( (x) => x._1 ).mkString(" '(","|", ")' ")
combine_sel_tag2: String = " '(big|medium)' "

scala> val combine_sel_all = " case when " + combine_sel_tag1 + " then " + combine_sel_tag2 +  sep_tag + " end as tags "
combine_sel_all: String = " case when  array_contains(big,country)  and  array_contains(medium,country)  then  '(big|medium)' when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium'  end as tags "

scala> val crosqry = tags.map( (x) =>  s" cross join ( select collect_list(country) as "+x._1+" from "+x._1+"_countries) "+ x._1 + "  "  ).mkString
crosqry: String = " cross join ( select collect_list(country) as big from big_countries) big   cross join ( select collect_list(country) as medium from medium_countries) medium  "

scala> val qry = " select name,email,phone,country, " + combine_sel_all + " from tagged_users " + crosqry
qry: String = " select name,email,phone,country,  case when  array_contains(big,country)  and  array_contains(medium,country)  then  '(big|medium)' when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium'  end as tags  from tagged_users  cross join ( select collect_list(country) as big from big_countries) big   cross join ( select collect_list(country) as medium from medium_countries) medium  "

scala> spark.sql(qry).show
+------+------------------+--------------+-------------+------------+
|  name|             email|         phone|      country|        tags|
+------+------------------+--------------+-------------+------------+
|  Mike|  mike@example.com|+91-9999999999|        Italy|         big|
|  Alex|  alex@example.com|+91-9999999998|       France|         big|
|  John|  john@example.com| +1-1111111111|United States|         big|
|Donald|donald@example.com| +1-2222222222|United States|         big|
|   Dan|   dan@example.com|+91-9999444999|       Poland|      medium|
| Scott| scott@example.com|+91-9111999998|        Spain|(big|medium)|
|   Rob|   rob@example.com|+91-9114444998|        Italy|         big|
+------+------------------+--------------+-------------+------------+


scala>
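As generated, this query has one branch requiring all tags plus one branch per single tag, so it covers n + 1 of the 2^n - 1 non-empty tag combinations; a row matching, say, `big` and `sometag` but not `medium` falls into the first single-tag branch and gets only `big`. The arithmetic for the three tags of the question, as a hypothetical check that is not part of the answer:

```scala
object ComboCount {
  val tags = Seq("big", "medium", "sometag")

  // Every non-empty set of tags a row could match
  val combos: Seq[Seq[String]] = (1 to tags.size).flatMap(k => tags.combinations(k).toSeq)

  // Branches in the generated CASE: one "all tags" branch plus one per tag
  val caseBranches: Int = tags.size + 1

  def main(args: Array[String]): Unit = {
    println(s"${combos.size} combinations, $caseBranches CASE branches")
    // Multi-tag combinations that have no exact branch
    combos.filter(c => c.size > 1 && c.size < tags.size).foreach(c => println(c.mkString("|")))
  }
}
```

With 1000 tags the number of combinations is astronomically larger than any CASE chain could enumerate, which is why per-tag expressions that are combined afterwards scale better.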

UPDATE 2:

scala> Seq(("Mike","mike@example.com","+91-9999999999","Italy"),
     | ("Alex","alex@example.com","+91-9999999998","France"),
     | ("John","john@example.com","+1-1111111111","United States"),
     | ("Donald","donald@example.com","+1-2222222222","United States"),
     | ("Dan","dan@example.com","+91-9999444999","Poland"),
     | ("Scott","scott@example.com","+91-9111999998","Spain"),
     | ("Rob","rob@example.com","+91-9114444998","Italy")
     | ).toDF("name","email","phone","country").createOrReplaceTempView("tagged_users")

scala> Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")

scala> Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")

scala> val tags = Map(
     |   "big" -> "country IN (SELECT * FROM big_countries)",
     |   "medium" -> "country IN (SELECT * FROM medium_countries)",
     |   "sometag" -> "name = 'Donald' AND email = 'donald@example.com' AND phone = '+1-2222222222'")
tags: scala.collection.immutable.Map[String,String] = Map(big -> country IN (SELECT * FROM big_countries), medium -> country IN (SELECT * FROM medium_countries), sometag -> name = 'Donald' AND email = 'donald@example.com' AND phone = '+1-2222222222')

scala> val sql_tags = tags.map { x => val p = x._2.trim.toUpperCase.split(" ");
     | val qry = if(p.contains("IN") && p.contains("FROM"))
     | s" case when array_contains((select collect_list("+p.head +") from " + p.last.replaceAll("[)]","")+ " ), " +p.head + " ) then '" + x._1 + " ' else '' end " + x._1 + " "
     | else
     | " case when " + x._2 + " then '" + x._1 + " ' else '' end " + x._1 + " ";
     | qry }.mkString(",")
sql_tags: String = " case when array_contains((select collect_list(COUNTRY) from BIG_COUNTRIES ), COUNTRY ) then 'big ' else '' end big , case when array_contains((select collect_list(COUNTRY) from MEDIUM_COUNTRIES ), COUNTRY ) then 'medium ' else '' end medium , case when name = 'Donald' AND email = 'donald@example.com' AND phone = '+1-2222222222' then 'sometag ' else '' end sometag "

scala> val outer_query = tags.map( x=> x._1).mkString(" regexp_replace(trim(concat(", ",", " )),' ','|') tags ")
outer_query: String = " regexp_replace(trim(concat(big,medium,sometag )),' ','|') tags "

scala> spark.sql(" select name,email, country, " + outer_query + " from ( select name,email, country ," + sql_tags + "   from tagged_users ) " ).show
+------+------------------+-------------+-----------+
|  name|             email|      country|       tags|
+------+------------------+-------------+-----------+
|  Mike|  mike@example.com|        Italy|        big|
|  Alex|  alex@example.com|       France|        big|
|  John|  john@example.com|United States|        big|
|Donald|donald@example.com|United States|big|sometag|
|   Dan|   dan@example.com|       Poland|     medium|
| Scott| scott@example.com|        Spain| big|medium|
|   Rob|   rob@example.com|        Italy|        big|
+------+------------------+-------------+-----------+


scala>
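`concat` returns NULL if any argument is NULL, which is why each per-tag column above falls back to `''` and the tags are joined through `trim` plus a `regexp_replace` of spaces; that trick would also split tag names that contain spaces. A sketch of an alternative, assuming every condition is already valid inside a SELECT list (Update 2 rewrites `IN (SELECT ...)` into `array_contains` over a scalar subquery, presumably because Spark 2.x accepts IN-subqueries only in filters): leave out `ELSE` so a non-matching tag is NULL, and join with `concat_ws`, which skips NULLs. `ConcatWsQuery` is a hypothetical helper, and tag names are assumed to contain no single quotes:

```scala
object ConcatWsQuery {
  // One NULL-or-tag expression per tag, combined by concat_ws in a single pass
  def buildQuery(tags: Seq[(String, String)], table: String, cols: Seq[String]): String = {
    val colList = cols.mkString(", ")
    val perTag  = tags.map { case (tag, cond) => s"CASE WHEN $cond THEN '$tag' END" }.mkString(", ")
    s"SELECT $colList, concat_ws('|', $perTag) AS tags FROM $table"
  }

  def main(args: Array[String]): Unit =
    println(buildQuery(Seq("sometag" -> "name = 'Donald'"), "tagged_users", Seq("name", "country")))
}
```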

Comments:

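Thanks, but I can have 200 or 1000 different tags in the map, so I don't think hardcoding case when array_contains(bcountry,country) will work. Also, per the system design, I'm restricted to using SQL queries rather than the Spark API to process the WHERE conditions.

Thanks, but I can't hardcode the set of tags in the SQL query. The tag collection must be accumulated while iterating over the map and evaluating the tag conditions.

@alexanoid.. I just updated the answer, assuming the tag tables are named <tag>_countries, and built the qry string. Hope this helps.

Thank you very much. Sorry, I may not have been clear before. I've updated my question with a new tags map. Tag conditions can be very complex and can differ significantly from one another. A condition can be based on multiple columns and formulas; it is up to the user to specify whatever condition they want. With that in mind, would your approach work in this case?

Yes.. if you look, the tags are not hardcoded anywhere.. everything is driven dynamically and the query is built up step by step into the final query.

Answer 2: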

If you need to aggregate the results rather than just execute each query, you can use map instead of foreach and then union the results:

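import org.apache.spark.sql.functions.lit

// One DataFrame per tag: the rows matching its condition, plus the tag name as a literal column
val o = tags.map { case (tag, tagCondition) =>
  spark.sql(buildTagQuery(tag, tagCondition, "tagged_users"))
    .withColumn("tag", lit(tag))
}

// Union all of them; a user matching several tags appears once per matching tag.
// reduce avoids the duplicated first DataFrame that foldLeft(o.head) would add.
val unioned = o.reduce(_ union _)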
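Because each tag query keeps its condition in a WHERE clause, the union idea can also be expressed entirely in SQL (matching the asker's SQL-only constraint), followed by a GROUP BY that folds the per-tag rows back into one row per user. A minimal sketch; `UnionTagQuery` is hypothetical, and it assumes the key columns identify a user and that tag names contain no single quotes. Users matching no tag drop out of the result, the order inside `collect_list` is not guaranteed, and each branch presumably scans `tagged_users` again unless the view is cached:

```scala
object UnionTagQuery {
  def buildQuery(tags: Seq[(String, String)], table: String, keyCols: Seq[String]): String = {
    val keys = keyCols.mkString(", ")
    // One SELECT per tag; the condition stays in WHERE, where IN-subqueries are allowed
    val branches = tags.map { case (tag, cond) =>
      s"SELECT $keys, '$tag' AS tag FROM $table WHERE $cond"
    }
    // Fold the per-tag rows back to one row per user
    s"SELECT $keys, concat_ws('|', collect_list(tag)) AS tags FROM (" +
      branches.mkString(" UNION ALL ") +
      s") t GROUP BY $keys"
  }

  def main(args: Array[String]): Unit = {
    val tags = Seq(
      "big"     -> "country IN (SELECT * FROM big_countries)",
      "sometag" -> "name = 'Donald' AND email = 'donald@example.com' AND phone = '+1-2222222222'"
    )
    println(buildQuery(tags, "tagged_users", Seq("name", "email", "phone", "country")))
  }
}
```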

Comments:

Thanks for the answer! Could you comment on how this approach would perform?

Answer 3:

I would define multiple tag tables, each with a column holding the tag value.

Your tag definitions would then be a collection such as Seq[(String, String)], where the first tuple element is the column on which the tag is computed.

Let's say:

Seq(
  "country" -> "bigCountries", // Columns [country, bigCountry]
  "country" -> "mediumCountries", // Columns [country, mediumCountry]
  "email" -> "hotmailLosers" // Columns [email, hotmailLoser]
)

Then iterate over this list and left-join each table on its associated column.

After each join, set your tag column to its current value plus the joined tag column (if that is not null).
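The answer describes this only in prose; a sketch of the SQL it might generate, assuming each tag table has the user-side key column plus one tag column (`TagTable`, `LeftJoinTagQuery`, and the column names are hypothetical). It handles only single-column equality tags, and a tag table with duplicate keys would duplicate user rows:

```scala
object LeftJoinTagQuery {
  // Rows of `table` whose `keyCol` equals the user's `keyCol` carry the tag value in `tagCol`
  final case class TagTable(keyCol: String, table: String, tagCol: String)

  def buildQuery(usersTable: String, tagTables: Seq[TagTable]): String = {
    val joins = tagTables.zipWithIndex.map { case (t, i) =>
      s"LEFT JOIN ${t.table} t$i ON u.${t.keyCol} = t$i.${t.keyCol}"
    }
    // Unmatched left joins give NULL tag columns, which concat_ws skips
    val tagList = tagTables.zipWithIndex.map { case (t, i) => s"t$i.${t.tagCol}" }.mkString(", ")
    s"SELECT u.*, concat_ws('|', $tagList) AS tags FROM $usersTable u " + joins.mkString(" ")
  }

  def main(args: Array[String]): Unit =
    println(buildQuery("tagged_users", Seq(
      TagTable("country", "bigCountries", "bigCountry"),
      TagTable("email", "hotmailLosers", "hotmailLoser"))))
}
```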

Comments:

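I can't use a left join or any other join, because the conditions for evaluating a particular tag can be very complex and based on multiple columns.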
