如何在 PySpark SQL when() 子句中使用聚合值?

Posted

技术标签:

【中文标题】如何在 PySpark SQL when() 子句中使用聚合值?【英文标题】:How do you use aggregated values within PySpark SQL when() clause? 【发布时间】:2021-12-03 23:24:12 【问题描述】:

我正在尝试学习 PySpark,并尝试学习如何使用 SQL when() 子句更好地对我的数据进行分类。 (请参阅此处:https://sparkbyexamples.com/spark/spark-case-when-otherwise-example/)我似乎无法解决的是如何将实际标量值插入 when() 条件以明确进行比较。似乎聚合函数返回的表格值比实际的 float() 类型更多。 我不断收到此错误消息unsupported operand type(s) for -: 'method' and 'method' 当我尝试运行函数来聚合原始数据框中的另一列时,我注意到了结果似乎不像表格那样是平面缩放器(agg(select(f.stddev("Col")) 给出的结果如下:"DataFrame[stddev_samp(TAXI_OUT): double]" )如果你想复制,这是我试图完成的一个示例,我想知道你如何在 when() 子句中获得像标准偏差和平均值这样的聚合值,所以你可以使用它来对您的新列进行分类:

samp = spark.createDataFrame(
    [("A","A1",4,1.25),("B","B3",3,2.14),("C","C2",7,4.24),("A","A3",4,1.25),("B","B1",3,2.14),("C","C1",7,4.24)],
    ["Category","Sub-cat","quantity","cost"])
  
    psMean = samp.agg('quantity':'mean')
    psStDev = samp.agg('quantity':'stddev')

    psCatVect = samp.withColumn('quant_category',.when(samp['quantity']<=(psMean-psStDev),'small').otherwise('not small')) ```  

【问题讨论】:

试试这个 psCatVect = samp.withColumn('quant_category',.when(samp['quantity'] 【参考方案1】:

您示例中的 psMean 和 psStdev 是数据帧,您需要使用 collect() 方法来提取标量值

psMean = samp.agg('quantity':'mean').collect()[0][0]
psStDev = samp.agg('quantity':'stddev').collect()[0][0]

【讨论】:

谢谢!我一直在搜索并且有一次收集(),但没有意识到我需要使用索引。我假设由于这是一个数据框,因此需要同时指定第一行和第一列;因此“[0][0]”?【参考方案2】:

您还可以将所有统计信息创建为 pandas DataFrame 的一个变量,并稍后在 pyspark 代码中引用它:

from pyspark.sql import functions as F

stats = (
    samp.select(
        F.mean("quantity").alias("mean"), 
        F.stddev("quantity").alias("std")
    ).toPandas()
)


(
    samp.withColumn('quant_category', 
                F.when(
                    samp['quantity'] <= stats["mean"].item() - stats["std"].item(), 
                    'small')
                .otherwise('not small')
               )
    .toPandas()
)

【讨论】:

谢谢 Pav3k,那么在这种情况下“item()”会替换 [][] 吗? 是的,那是因为数据类型不同。 stats 是只有一行的 pandas DataFrame,因此如果您键入 stats["mean"] 您返回了大小为 1 的 pandas Series。因为这只有 1 个元素,您可以使用 stats["mean"].item() 来提取标量价值。如果那里有超过 1 行,该 .item() 方法将返回 ValueError: can only convert a array of size 1 to a Python scalar.

以上是关于如何在 PySpark SQL when() 子句中使用聚合值?的主要内容,如果未能解决你的问题,请参考以下文章

PySpark:when子句中的多个条件

如何在 pyspark.sql.functions.when() 中使用多个条件?

如何在 PySpark 中编写条件正则表达式替换?

SQL 中的自引用 CASE WHEN 子句

Oracle SQL 不同的 where 子句与 case when

PySpark 单元测试方法