使用pyspark中的条件创建具有运行总量的列

Posted

技术标签:

【中文标题】使用pyspark中的条件创建具有运行总量的列【英文标题】:Create column with running total amount with conditions in pyspark 【发布时间】:2021-06-14 22:38:34 【问题描述】:

我有一个包含“account”、“order_date”和“amount”列的数据集。

我需要根据条件创建“余额”列。这里棘手的部分是“余额”列中的当前行取决于正在创建的同一列的上一行。

这是逻辑解释:

如果金额 > 0 则

  amount + [Row-1:balance]

其他

min([Row-1:balance],0) + amount

预期结果:

【问题讨论】:

我确实尝试了一种使用 spark-sql 的 case 语句的解决方案,但是在包含 2200 万条记录的整个数据集上运行时,该解决方案的性能很糟糕。 仅供参考,这里 [Row-1:balance] 的意思是,计算“余额”列的前一行值。 【参考方案1】:

使用 Pandas UDF (Spark >= 2.3)

import pandas as pd
import pyspark.sql.functions as f
from pyspark.sql.functions import pandas_udf, PandasUDFType
import datetime as dt

        data = [
            'account': '1', 'order_date': '11/18/20', 'amount': -34.99,
            'account': '1', 'order_date': '10/28/20', 'amount': -4.99,
            'account': '1', 'order_date': '9/11/20', 'amount': 4.99,
            'account': '1', 'order_date': '9/2/20', 'amount': 9.98]
        # For simiplicity, creating a new column "balance" with 0.0 
        input_df = self._spark.createDataFrame(data).withColumn('balance', f.lit(0.0))
        input_df.show()
        schema = input_df.schema

        @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
        def _get_running_total(input: pd.DataFrame):
            import os
            # To fix a bug in pyarrow newer versions
            os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = "1" 
            previous_balance = None
            input['order_date'] = pd.to_datetime(input['order_date'])
            df = input.sort_values(by=['order_date'], ascending=False)
            df['order_date'] = df['order_date'].apply(lambda x: dt.datetime.strftime(x, '%m/%d/%Y'))
            for i, row in df.iterrows():
                current_amount = row['amount']
                if i == 0:
                    running_total = current_amount
                else:
                    if current_amount > 0:
                        running_total = current_amount + previous_balance
                    else:
                        if previous_balance > 0:
                            previous_balance = 0
                        running_total = previous_balance + current_amount
                df._set_value(i, 'balance', running_total)
                previous_balance = running_total
            return df
        input_df.groupby('account').apply(_get_running_total).show()

已更新,在 3.2 上运行良好

import pandas as pd
import pyspark.sql.functions as f
from pyspark.sql.functions import pandas_udf, PandasUDFType
import datetime as dt
from pyspark import SparkContext
from pyspark.sql import SparkSession


_spark = SparkSession.builder.appName("SparkByExamples.com").getOrCreate()  

data = [
    'account': '1', 'order_date': '11/18/20', 'amount': -34.99,
    'account': '1', 'order_date': '10/28/20', 'amount': -4.99,
    'account': '1', 'order_date': '9/11/20', 'amount': 4.99,
    'account': '1', 'order_date': '9/2/20', 'amount': 9.98]
# For simiplicity, creating a new column "balance" with 0.0 
input_df = _spark.createDataFrame(data).withColumn('balance', f.lit(0.0))
input_df.show()
schema = input_df.schema

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def _get_running_total(input: pd.DataFrame):
    import os
    # To fix a bug in pyarrow newer versions
    os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = "1" 
    previous_balance = None
    input['order_date'] = pd.to_datetime(input['order_date'])
    df = input.sort_values(by=['order_date'], ascending=False)
    df['order_date'] = df['order_date'].apply(lambda x: dt.datetime.strftime(x, '%m/%d/%Y'))
    for i, row in df.iterrows():
        current_amount = row['amount']
        if i == 0:
            running_total = current_amount
        else:
            if current_amount > 0:
                running_total = current_amount + previous_balance
            else:
                if previous_balance > 0:
                    previous_balance = 0
                running_total = previous_balance + current_amount
        df._set_value(i, 'balance', running_total)
        previous_balance = running_total
    return df

input_df.groupby('account').apply(_get_running_total).show()

【讨论】:

感谢您的回复。在这里您只是对记录进行排序,因为您只使用了一个帐户,我们应该将记录分组,然后使用 order_date 在每个组内排序,我们应该对每组记录应用此逻辑。 @Anand - 在计算每个组/帐户的运行总数之前,排序已经在 UDF 中进行。请检查! 它在最后一行,明白了。谢谢@veerat 这完全没问题,改动很小。感谢您在这方面的时间。【参考方案2】:

您是否尝试过使用 Windowing & spark lag 功能?

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import functions as F
from pyspark.sql.window import Window

sc = SparkContext(appName="PRowDiffApp")
sqlc = SQLContext(sc)

rdd = sc.parallelize([(1, 65), (2, 66), (3, 65), (4, 68), (5, 71)])

df = sqlc.createDataFrame(rdd, ["id", "value"])

my_window = Window.partitionBy().orderBy("id")

df = df.withColumn("prev_value", F.lag(df.value).over(my_window))
df = df.withColumn("diff", F.when(F.isnull(df.value - df.prev_value), 0)
                              .otherwise(df.value - df.prev_value))

df.show()

【讨论】:

我确实以几种不同的方式尝试了 Window 和 lag 功能,我用谷歌搜索了很多可用的信息,但无法完全解决这个问题。如果我们能得到可靠的工作解决方案,那就太好了。

以上是关于使用pyspark中的条件创建具有运行总量的列的主要内容,如果未能解决你的问题,请参考以下文章

如何使用pyspark将具有多个可能值的Json数组列表转换为数据框中的列

数据框在多列上连接,pyspark中的列有一些条件[重复]

Pyspark:添加具有 groupby 平均值的列

Pyspark - 如何拆分具有 Datetime 类型的结构值的列?

从 pyspark 中的字典列创建数据框

在另一列pyspark中创建具有字符串长度的列