如何在 PySpark 的 UDF 中返回“元组类型”?
Posted
技术标签:
【中文标题】如何在 PySpark 的 UDF 中返回“元组类型”?【英文标题】:How to return a "Tuple type" in a UDF in PySpark? 【发布时间】:2016-08-18 20:14:39 【问题描述】:所有data types in pyspark.sql.types
are:
__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
"TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
"LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]
我必须编写一个返回元组数组的 UDF(在 pyspark 中)。我给它的第二个参数是什么,它是 udf 方法的返回类型?这将是ArrayType(TupleType())
...
【问题讨论】:
您的标题问题似乎与正文不匹配。文档没有告诉您如何将返回值设置为“其他类型的容器类型”? @jonrsharpe 我已更改标题。希望它现在能代表身体。 【参考方案1】:在 Spark 中没有 TupleType
这样的东西。产品类型表示为structs
,具有特定类型的字段。例如,如果你想返回一个数组对(整数、字符串),你可以使用这样的模式:
from pyspark.sql.types import *
schema = ArrayType(StructType([
StructField("char", StringType(), False),
StructField("count", IntegerType(), False)
]))
示例用法:
from pyspark.sql.functions import udf
from collections import Counter
char_count_udf = udf(
lambda s: Counter(s).most_common(),
schema
)
df = sc.parallelize([(1, "foo"), (2, "bar")]).toDF(["id", "value"])
df.select("*", char_count_udf(df["value"])).show(2, False)
## +---+-----+-------------------------+
## |id |value|PythonUDF#<lambda>(value)|
## +---+-----+-------------------------+
## |1 |foo |[[o,2], [f,1]] |
## |2 |bar |[[r,1], [a,1], [b,1]] |
## +---+-----+-------------------------+
【讨论】:
您的回答有效,但我的情况有点复杂。我的返回数据是[('a1', [('b1', 1), ('b2', 2)]), ('a2', [('b1', 1), ('b2', 2)])]
类型,所以我将类型设为ArrayType(StructType([StructField("date", StringType(), False), ArrayType(StructType([StructField("hashId", StringType(), False), StructField("TimeSpent-Front", FloatType(), False), StructField("TimeSpent-Back", FloatType(), False)]))]))
,这使得 'ArrayType' 对象没有属性'name'...
StructType
需要StructFields
的序列,因此您不能单独使用ArrayTypes
。你需要StructField
存储ArrayType
。还有忠告——如果你发现自己在创建这样的结构,你可能应该重新考虑数据模型。没有 UDF 就很难处理深度嵌套的结构,而 Python UDF 远没有效率。
如何在 udf 中指定模式以返回列表。 F.udf(lambda start_date, end_date : [0,1] if start_date
【参考方案2】:
*** 一直在引导我回答这个问题,所以我想我会在这里添加一些信息。
从 UDF 返回简单类型:
from pyspark.sql.types import *
from pyspark.sql import functions as F
def get_df():
d = [(0.0, 0.0), (0.0, 3.0), (1.0, 6.0), (1.0, 9.0)]
df = sqlContext.createDataFrame(d, ['x', 'y'])
return df
df = get_df()
df.show()
# +---+---+
# | x| y|
# +---+---+
# |0.0|0.0|
# |0.0|3.0|
# |1.0|6.0|
# |1.0|9.0|
# +---+---+
func = udf(lambda x: str(x), StringType())
df = df.withColumn('y_str', func('y'))
func = udf(lambda x: int(x), IntegerType())
df = df.withColumn('y_int', func('y'))
df.show()
# +---+---+-----+-----+
# | x| y|y_str|y_int|
# +---+---+-----+-----+
# |0.0|0.0| 0.0| 0|
# |0.0|3.0| 3.0| 3|
# |1.0|6.0| 6.0| 6|
# |1.0|9.0| 9.0| 9|
# +---+---+-----+-----+
df.printSchema()
# root
# |-- x: double (nullable = true)
# |-- y: double (nullable = true)
# |-- y_str: string (nullable = true)
# |-- y_int: integer (nullable = true)
当整数不够用时:
df = get_df()
func = udf(lambda x: [0]*int(x), ArrayType(IntegerType()))
df = df.withColumn('list', func('y'))
func = udf(lambda x: float(y): str(y) for y in range(int(x)),
MapType(FloatType(), StringType()))
df = df.withColumn('map', func('y'))
df.show()
# +---+---+--------------------+--------------------+
# | x| y| list| map|
# +---+---+--------------------+--------------------+
# |0.0|0.0| []| Map()|
# |0.0|3.0| [0, 0, 0]|Map(2.0 -> 2, 0.0...|
# |1.0|6.0| [0, 0, 0, 0, 0, 0]|Map(0.0 -> 0, 5.0...|
# |1.0|9.0|[0, 0, 0, 0, 0, 0...|Map(0.0 -> 0, 5.0...|
# +---+---+--------------------+--------------------+
df.printSchema()
# root
# |-- x: double (nullable = true)
# |-- y: double (nullable = true)
# |-- list: array (nullable = true)
# | |-- element: integer (containsNull = true)
# |-- map: map (nullable = true)
# | |-- key: float
# | |-- value: string (valueContainsNull = true)
从 UDF 返回复杂数据类型:
df = get_df()
df = df.groupBy('x').agg(F.collect_list('y').alias('y[]'))
df.show()
# +---+----------+
# | x| y[]|
# +---+----------+
# |0.0|[0.0, 3.0]|
# |1.0|[9.0, 6.0]|
# +---+----------+
schema = StructType([
StructField("min", FloatType(), True),
StructField("size", IntegerType(), True),
StructField("edges", ArrayType(FloatType()), True),
StructField("val_to_index", MapType(FloatType(), IntegerType()), True)
# StructField('insanity', StructType([StructField("min_", FloatType(), True), StructField("size_", IntegerType(), True)]))
])
def func(values):
mn = min(values)
size = len(values)
lst = sorted(values)[::-1]
val_to_index = x: i for i, x in enumerate(values)
return (mn, size, lst, val_to_index)
func = udf(func, schema)
dff = df.select('*', func('y[]').alias('complex_type'))
dff.show(10, False)
# +---+----------+------------------------------------------------------+
# |x |y[] |complex_type |
# +---+----------+------------------------------------------------------+
# |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]|
# |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]|
# +---+----------+------------------------------------------------------+
dff.printSchema()
# +---+----------+------------------------------------------------------+
# |x |y[] |complex_type |
# +---+----------+------------------------------------------------------+
# |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]|
# |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]|
# +---+----------+------------------------------------------------------+
将多个参数传递给 UDF:
df = get_df()
func = udf(lambda arr: arr[0]*arr[1],FloatType())
df = df.withColumn('x*y', func(F.array('x', 'y')))
# +---+---+---+
# | x| y|x*y|
# +---+---+---+
# |0.0|0.0|0.0|
# |0.0|3.0|0.0|
# |1.0|6.0|6.0|
# |1.0|9.0|9.0|
# +---+---+---+
代码纯粹用于演示目的,以上所有转换都可以在 Spark 代码中使用,并且会产生更好的性能。 正如上面评论中的@zero323,在 pyspark 中通常应避免使用 UDF;返回复杂类型应该让您考虑简化逻辑。
【讨论】:
【参考方案3】:对于 scala 版本而不是 python。 2.4版
import org.apache.spark.sql.types._
val testschema : StructType = StructType(
StructField("number", IntegerType) ::
StructField("Array", ArrayType(StructType(StructField("cnt_rnk", IntegerType) :: StructField("comp", StringType) :: Nil))) ::
StructField("comp", StringType):: Nil)
树形结构如下所示。
testschema.printTreeString
root
|-- number: integer (nullable = true)
|-- Array: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- cnt_rnk: integer (nullable = true)
| | |-- corp_id: string (nullable = true)
|-- comp: string (nullable = true)
【讨论】:
以上是关于如何在 PySpark 的 UDF 中返回“元组类型”?的主要内容,如果未能解决你的问题,请参考以下文章
如何在 Pyspark 中使用 @pandas_udf 返回多个数据帧?
C#关键字扫盲——Tuple(元组类) ValueTuple(值元组)