使用条件结果列连接 PySpark 数据帧
Posted
技术标签:
【中文标题】使用条件结果列连接 PySpark 数据帧【英文标题】:Joining PySpark dataframes with conditional result column 【发布时间】:2021-03-01 20:38:58 【问题描述】:我有这些表:
df1 df2
+---+------------+ +---+---------+
| id| many_cols| | id|criterion|
+---+------------+ +---+---------+
| 1|lots_of_data| | 1| false|
| 2|lots_of_data| | 1| true|
| 3|lots_of_data| | 1| true|
+---+------------+ | 3| false|
+---+---------+
我打算在df1
中创建额外的列:
+---+------------+------+
| id| many_cols|result|
+---+------------+------+
| 1|lots_of_data| 1|
| 2|lots_of_data| null|
| 3|lots_of_data| 0|
+---+------------+------+
如果df2
中有对应的true
,result
应该是1
如果df2
中没有对应的true
,result
应该是0
@987654332如果df2
中没有对应的id
,@应该是null
我想不出一种有效的方法来做到这一点。加入后,我只遇到第三个条件:
df = df1.join(df2, 'id', 'full')
df.show()
# +---+------------+---------+
# | id| many_cols|criterion|
# +---+------------+---------+
# | 1|lots_of_data| false|
# | 1|lots_of_data| true|
# | 1|lots_of_data| true|
# | 3|lots_of_data| false|
# | 2|lots_of_data| null|
# +---+------------+---------+
PySpark 数据帧是这样创建的:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.getOrCreate()
df1cols = ['id', 'many_cols']
df1data = [(1, 'lots_of_data'),
(2, 'lots_of_data'),
(3, 'lots_of_data')]
df2cols = ['id', 'criterion']
df2data = [(1, False),
(1, True),
(1, True),
(3, None)]
df1 = spark.createDataFrame(df1data, df1cols)
df2 = spark.createDataFrame(df2data, df2cols)
【问题讨论】:
【参考方案1】:一个简单的方法是 groupby df2
通过id
获得最大的criterion
与df1
的连接,这样可以减少要连接的行数。如果至少有一个对应的真值,则布尔列的最大值为真:
from pyspark.sql import functions as F
df2_group = df2.groupBy("id").agg(F.max("criterion").alias("criterion"))
result = df1.join(df2_group, ["id"], "left").withColumn(
"result",
F.col("criterion").cast("int")
).drop("criterion")
result.show()
#+---+------------+------+
#| id| many_cols|result|
#+---+------------+------+
#| 1|lots_of_data| 1|
#| 3|lots_of_data| 0|
#| 2|lots_of_data| null|
#+---+------------+------+
【讨论】:
【参考方案2】:您可以尝试关联子查询以从 df2 获取最大布尔值,并将其转换为整数。
df1.createOrReplaceTempView('df1')
df2.createOrReplaceTempView('df2')
df = spark.sql("""
select
df1.*,
(select int(max(criterion)) from df2 where df1.id = df2.id) as result
from df1
""")
df.show()
+---+------------+------+
| id| many_cols|result|
+---+------------+------+
| 1|lots_of_data| 1|
| 3|lots_of_data| 0|
| 2|lots_of_data| null|
+---+------------+------+
【讨论】:
【参考方案3】:查看此解决方案。加入后。您可以根据您的要求使用多个条件检查,并使用 when 子句相应地分配值,然后按 id 和其他列获取结果分组的最大值。如果您只使用 id 作为分区,您也可以使用窗口函数来计算结果的最大值。
from pyspark.sql import functions as F
from pyspark.sql.window import Window
df1cols = ['id', 'many_cols']
df1data = [(1, 'lots_of_data'),
(2, 'lots_of_data'),
(3, 'lots_of_data')]
df2cols = ['id', 'criterion']
df2data = [(1, False),
(1, True),
(1, True),
(3, False)]
df1 = spark.createDataFrame(df1data, df1cols)
df2 = spark.createDataFrame(df2data, df2cols)
df2_mod =df2.withColumnRenamed("id", "id_2")
df3=df1.join(df2_mod, on=df1.id== df2_mod.id_2, how='left')
cond1 = (F.col("id")== F.col("id_2"))& (F.col("criterion")==1)
cond2 = (F.col("id")== F.col("id_2"))& (F.col("criterion")==0)
cond3 = (F.col("id_2").isNull())
df3.select("id", "many_cols", F.when(cond1, 1).when(cond2,0).when(cond3, F.lit(None)).alias("result"))\
.groupBy("id", "many_cols").agg(F.max(F.col("result")).alias("result")).orderBy("id").show()
Result:
------
+---+------------+------+
| id| many_cols|result|
+---+------------+------+
| 1|lots_of_data| 1|
| 2|lots_of_data| null|
| 3|lots_of_data| 0|
+---+------------+------+
使用窗口函数
w=Window().partitionBy("id")
df3.select("id", "many_cols", F.when(cond1, 1).when(cond2,0).when(cond3, F.lit(None)).alias("result"))\
.select("id", "many_cols", F.max("result").over(w).alias("result")).drop_duplicates().show()
【讨论】:
非常感谢您将条件重构为单独变量的绝妙想法。这是我真实用例中的关键。然而,由于最终版本演变成一个完全不同的东西,我不能接受这个作为答案。我投了一个赞成票,但你至少应该得到两个 :)【参考方案4】:我必须合并建议答案的想法,以获得最适合我的解决方案。
# The `cond` variable is very useful, here it represents several complex conditions
cond = F.col('criterion') == True
df2_grp = df2.select(
'id',
F.when(cond, 1).otherwise(0).alias('c')
).groupBy('id').agg(F.max(F.col('c')).alias('result'))
df = df1.join(df2_grp, 'id', 'left')
df.show()
#+---+------------+------+
#| id| many_cols|result|
#+---+------------+------+
#| 1|lots_of_data| 1|
#| 3|lots_of_data| 0|
#| 2|lots_of_data| null|
#+---+------------+------+
【讨论】:
以上是关于使用条件结果列连接 PySpark 数据帧的主要内容,如果未能解决你的问题,请参考以下文章
Pyspark:内部连接两个 pyspark 数据帧并从第一个数据帧中选择所有列,从第二个数据帧中选择几列