从 UDF 返回 StructType 的 ArrayType 时出错(并在多个 UDF 中使用单个函数)

Posted

技术标签:

【中文标题】从 UDF 返回 StructType 的 ArrayType 时出错(并在多个 UDF 中使用单个函数)【英文标题】:Error when returning an ArrayType of StructType from UDF (and using a single function in multiple UDFs) 【发布时间】:2019-08-07 11:10:17 【问题描述】:

(编辑)更改了字段名称(从 foo、bar、... 到 name 和 city),因为旧的命名令人困惑

我需要在多个 UDF 中使用单个函数并根据输入返回不同的结构。

我的实现的这个简化版本基本上可以满足我的要求:

from pyspark.sql.types import IntegerType, StructType, StringType
from pyspark.sql.functions import when, col

df = spark.createDataFrame([1, 2, 3], IntegerType()).toDF('id')

struct_one = StructType().add('name', StringType(), True)
struct_not_one = StructType().add('city', StringType(), True)

def select(id):
  if id == 1:
    return 'name': 'Alice'
  else:
    return 'city': 'Seattle'

one_udf = udf(select, struct_one)
not_one_udf = udf(select, struct_not_one)

df = df.withColumn('one', when((col('id') == 1), one_udf(col('id'))))\
       .withColumn('not_one', when((col('id') != 1), not_one_udf(col('id'))))

display(df)   

(编辑)输出:

id  one               not_one
1   "name":"Alice"  null
2   null              "city":"Seattle"
3   null              "city":"Seattle"

但是,不幸的是,返回 ArrayType 为 StructType 的相同代码失败了:

from pyspark.sql.types import IntegerType, StructType, StringType, ArrayType
from pyspark.sql.functions import when, col

df = spark.createDataFrame([1, 2, 3], IntegerType()).toDF('id')

struct_one = StructType().add('name', StringType(), True)
struct_not_one = ArrayType(StructType().add('city', StringType(), True))

def select(id):
  if id == 1:
    return 'name': 'Alice'
  else:
    return ['city': 'Seattle', 'city': 'Milan']

one_udf = udf(select, struct_one)
not_one_udf = udf(select, struct_not_one)

df = df.withColumn('one', when((col('id') == 1), one_udf(col('id'))))\
       .withColumn('not_one', when((col('id') != 1), not_one_udf(col('id'))))

display(df)      

错误信息是:

ValueError: 带有 StructType 的意外元组“名称”

(编辑)所需的输出是:

id  one                 not_one
1   "name":"Alice"    null
2   null                ["city":"Seattle","city":"Milan"]
3   null                ["city":"Seattle","city":"Milan"]

但是,其他类型(StringType、IntegerType、...)的返回和 ArrayType 也可以。

在多个 UDF 中不使用单个函数时也可以返回 StructType 数组:

from pyspark.sql.types import IntegerType, StructType, StringType, ArrayType
from pyspark.sql.functions import when, col

df = spark.createDataFrame([1, 2, 3], IntegerType()).toDF('id')

struct_not_one = ArrayType(StructType().add('city', StringType(), True))

def select(id):
    return ['city': 'Seattle', 'city': 'Milan']

not_one_udf = udf(select, struct_not_one)

df = df.withColumn('not_one', when((col('id') != 1), not_one_udf(col('id'))))

display(df)   

(编辑)输出:

id  not_one
1   null
2   ["city":"Seattle","city":"Milan"]
3   ["city":"Seattle","city":"Milan"]

为什么返回 StructType 的 ArrayType 并使用单个函数的多个 UDF 不起作用?

谢谢!

【问题讨论】:

你能分享一个输入和期望输出的例子吗? 我已经编辑了原始帖子中的输出(+期望输出)。输入只是由以下行创建的小样本数据框: df = spark.createDataFrame([1, 2, 3], IntegerType()).toDF('id') 试试这个:df.withColumn("one", when(col("id") === 1, typedLit(Map("foo" -> "bar")))) .withColumn("not_one", when(col("id") =!= 1, typedLit(Map("foo" -> ("bar", "baz")))))。您的 "foo" : "bar" 是一个唯一的字典(按键),"foo" 应该只出现一次,同时具有两个值:"bar" & "baz" 感谢尼尔的评论。我更改了名称,因为它有点令人困惑。不幸的是,我不能使用 typedLit(),因为在我的实际实现中,select() 函数会动态创建返回值。在示例中,我将其简化为返回的硬编码值。 【参考方案1】:

“Spark SQL(包括 SQL 以及 DataFrame 和 Dataset API)不保证子表达式的求值顺序... 因此,依赖布尔表达式的副作用或求值顺序以及 WHERE 和 HAVING 子句的顺序是危险的,因为在查询优化和规划期间可以重新排序此类表达式和子句。具体来说,如果 UDF 依赖 SQL 中的短路语义进行空值检查,则无法保证在调用 UDF 之前会发生空值检查。”

见Evaluation order and null checking

为了保持您的 udf 通用性,您可以将“when filter”推送到您的 udf 中:

from pyspark.sql.types import IntegerType, StructType, StringType, ArrayType
from pyspark.sql.functions import when, col, lit

df = spark.createDataFrame([1, 2, 3], IntegerType()).toDF('id')

struct_one = StructType().add('name', StringType(), True)
struct_not_one = ArrayType(StructType().add('city', StringType(), True))

def select(id, test):

  if eval(test.format(id)) is False:
    return None

  if id == 1:
    return 'name': 'Alice'
  else:
    return ['city': 'Seattle', 'city': 'Milan']

one_udf = udf(select, struct_one)
not_one_udf = udf(select, struct_not_one)

df = df.withColumn('one', one_udf(col('id'), lit(' == 1')))\
       .withColumn('not_one', not_one_udf(col('id'), lit(' != 1')))

display(df)    

【讨论】:

以上是关于从 UDF 返回 StructType 的 ArrayType 时出错(并在多个 UDF 中使用单个函数)的主要内容,如果未能解决你的问题,请参考以下文章

在 StructType 数组上应用 UDF

Apache Spark SQL StructType 和 UDF

UDF 的输入类型应该是啥类型的列 - StructType 数组或“null”?

如何将火花行(StructType)投射到scala案例类

Spark之UDF

pyspark 在 udf 中获取结构数据类型的字段名称