如何在 PySpark SQL when() 子句中使用聚合值?
Posted
技术标签:
【中文标题】如何在 PySpark SQL when() 子句中使用聚合值?【英文标题】:How do you use aggregated values within PySpark SQL when() clause? 【发布时间】:2021-12-03 23:24:12 【问题描述】:我正在尝试学习 PySpark,并尝试学习如何使用 SQL when() 子句更好地对我的数据进行分类。 (请参阅此处:https://sparkbyexamples.com/spark/spark-case-when-otherwise-example/)我似乎无法解决的是如何将实际标量值插入 when() 条件以明确进行比较。似乎聚合函数返回的表格值比实际的 float() 类型更多。 我不断收到此错误消息unsupported operand type(s) for -: 'method' and 'method' 当我尝试运行函数来聚合原始数据框中的另一列时,我注意到了结果似乎不像表格那样是平面缩放器(agg(select(f.stddev("Col")) 给出的结果如下:"DataFrame[stddev_samp(TAXI_OUT): double]" )如果你想复制,这是我试图完成的一个示例,我想知道你如何在 when() 子句中获得像标准偏差和平均值这样的聚合值,所以你可以使用它来对您的新列进行分类:
samp = spark.createDataFrame(
[("A","A1",4,1.25),("B","B3",3,2.14),("C","C2",7,4.24),("A","A3",4,1.25),("B","B1",3,2.14),("C","C1",7,4.24)],
["Category","Sub-cat","quantity","cost"])
psMean = samp.agg('quantity':'mean')
psStDev = samp.agg('quantity':'stddev')
psCatVect = samp.withColumn('quant_category',.when(samp['quantity']<=(psMean-psStDev),'small').otherwise('not small')) ```
【问题讨论】:
试试这个 psCatVect = samp.withColumn('quant_category',.when(samp['quantity'] 【参考方案1】:您示例中的 psMean 和 psStdev 是数据帧,您需要使用 collect() 方法来提取标量值
psMean = samp.agg('quantity':'mean').collect()[0][0]
psStDev = samp.agg('quantity':'stddev').collect()[0][0]
【讨论】:
谢谢!我一直在搜索并且有一次收集(),但没有意识到我需要使用索引。我假设由于这是一个数据框,因此需要同时指定第一行和第一列;因此“[0][0]”?【参考方案2】:您还可以将所有统计信息创建为 pandas DataFrame 的一个变量,并稍后在 pyspark 代码中引用它:
from pyspark.sql import functions as F
stats = (
samp.select(
F.mean("quantity").alias("mean"),
F.stddev("quantity").alias("std")
).toPandas()
)
(
samp.withColumn('quant_category',
F.when(
samp['quantity'] <= stats["mean"].item() - stats["std"].item(),
'small')
.otherwise('not small')
)
.toPandas()
)
【讨论】:
谢谢 Pav3k,那么在这种情况下“item()”会替换 [][] 吗? 是的,那是因为数据类型不同。 stats 是只有一行的 pandas DataFrame,因此如果您键入 stats["mean"] 您返回了大小为 1 的 pandas Series。因为这只有 1 个元素,您可以使用 stats["mean"].item() 来提取标量价值。如果那里有超过 1 行,该 .item() 方法将返回 ValueError: can only convert a array of size 1 to a Python scalar.以上是关于如何在 PySpark SQL when() 子句中使用聚合值?的主要内容,如果未能解决你的问题,请参考以下文章
如何在 pyspark.sql.functions.when() 中使用多个条件?