使用pyspark中的条件创建具有运行总量的列
Posted
技术标签:
【中文标题】使用pyspark中的条件创建具有运行总量的列【英文标题】:Create column with running total amount with conditions in pyspark 【发布时间】:2021-06-14 22:38:34 【问题描述】:我有一个包含“account”、“order_date”和“amount”列的数据集。
我需要根据条件创建“余额”列。这里棘手的部分是“余额”列中的当前行取决于正在创建的同一列的上一行。
这是逻辑解释:
如果金额 > 0 则
amount + [Row-1:balance]
其他
min([Row-1:balance],0) + amount
预期结果:
【问题讨论】:
我确实尝试了一种使用 spark-sql 的 case 语句的解决方案,但是在包含 2200 万条记录的整个数据集上运行时,该解决方案的性能很糟糕。 仅供参考,这里 [Row-1:balance] 的意思是,计算“余额”列的前一行值。 【参考方案1】:使用 Pandas UDF (Spark >= 2.3)
import pandas as pd
import pyspark.sql.functions as f
from pyspark.sql.functions import pandas_udf, PandasUDFType
import datetime as dt
data = [
'account': '1', 'order_date': '11/18/20', 'amount': -34.99,
'account': '1', 'order_date': '10/28/20', 'amount': -4.99,
'account': '1', 'order_date': '9/11/20', 'amount': 4.99,
'account': '1', 'order_date': '9/2/20', 'amount': 9.98]
# For simiplicity, creating a new column "balance" with 0.0
input_df = self._spark.createDataFrame(data).withColumn('balance', f.lit(0.0))
input_df.show()
schema = input_df.schema
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def _get_running_total(input: pd.DataFrame):
import os
# To fix a bug in pyarrow newer versions
os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = "1"
previous_balance = None
input['order_date'] = pd.to_datetime(input['order_date'])
df = input.sort_values(by=['order_date'], ascending=False)
df['order_date'] = df['order_date'].apply(lambda x: dt.datetime.strftime(x, '%m/%d/%Y'))
for i, row in df.iterrows():
current_amount = row['amount']
if i == 0:
running_total = current_amount
else:
if current_amount > 0:
running_total = current_amount + previous_balance
else:
if previous_balance > 0:
previous_balance = 0
running_total = previous_balance + current_amount
df._set_value(i, 'balance', running_total)
previous_balance = running_total
return df
input_df.groupby('account').apply(_get_running_total).show()
已更新,在 3.2 上运行良好
import pandas as pd
import pyspark.sql.functions as f
from pyspark.sql.functions import pandas_udf, PandasUDFType
import datetime as dt
from pyspark import SparkContext
from pyspark.sql import SparkSession
_spark = SparkSession.builder.appName("SparkByExamples.com").getOrCreate()
data = [
'account': '1', 'order_date': '11/18/20', 'amount': -34.99,
'account': '1', 'order_date': '10/28/20', 'amount': -4.99,
'account': '1', 'order_date': '9/11/20', 'amount': 4.99,
'account': '1', 'order_date': '9/2/20', 'amount': 9.98]
# For simiplicity, creating a new column "balance" with 0.0
input_df = _spark.createDataFrame(data).withColumn('balance', f.lit(0.0))
input_df.show()
schema = input_df.schema
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def _get_running_total(input: pd.DataFrame):
import os
# To fix a bug in pyarrow newer versions
os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = "1"
previous_balance = None
input['order_date'] = pd.to_datetime(input['order_date'])
df = input.sort_values(by=['order_date'], ascending=False)
df['order_date'] = df['order_date'].apply(lambda x: dt.datetime.strftime(x, '%m/%d/%Y'))
for i, row in df.iterrows():
current_amount = row['amount']
if i == 0:
running_total = current_amount
else:
if current_amount > 0:
running_total = current_amount + previous_balance
else:
if previous_balance > 0:
previous_balance = 0
running_total = previous_balance + current_amount
df._set_value(i, 'balance', running_total)
previous_balance = running_total
return df
input_df.groupby('account').apply(_get_running_total).show()
【讨论】:
感谢您的回复。在这里您只是对记录进行排序,因为您只使用了一个帐户,我们应该将记录分组,然后使用 order_date 在每个组内排序,我们应该对每组记录应用此逻辑。 @Anand - 在计算每个组/帐户的运行总数之前,排序已经在 UDF 中进行。请检查! 它在最后一行,明白了。谢谢@veerat 这完全没问题,改动很小。感谢您在这方面的时间。【参考方案2】:您是否尝试过使用 Windowing & spark lag 功能?
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import functions as F
from pyspark.sql.window import Window
sc = SparkContext(appName="PRowDiffApp")
sqlc = SQLContext(sc)
rdd = sc.parallelize([(1, 65), (2, 66), (3, 65), (4, 68), (5, 71)])
df = sqlc.createDataFrame(rdd, ["id", "value"])
my_window = Window.partitionBy().orderBy("id")
df = df.withColumn("prev_value", F.lag(df.value).over(my_window))
df = df.withColumn("diff", F.when(F.isnull(df.value - df.prev_value), 0)
.otherwise(df.value - df.prev_value))
df.show()
【讨论】:
我确实以几种不同的方式尝试了 Window 和 lag 功能,我用谷歌搜索了很多可用的信息,但无法完全解决这个问题。如果我们能得到可靠的工作解决方案,那就太好了。以上是关于使用pyspark中的条件创建具有运行总量的列的主要内容,如果未能解决你的问题,请参考以下文章
如何使用pyspark将具有多个可能值的Json数组列表转换为数据框中的列