如何在 Pyspark 中使用 @pandas_udf 返回多个数据帧?

Posted

技术标签:

【中文标题】如何在 Pyspark 中使用 @pandas_udf 返回多个数据帧?【英文标题】:How to return multiple dataframes using @pandas_udf in Pyspark? 【发布时间】:2021-01-28 17:19:59 【问题描述】:

我想为 Pyspark 创建 sklearn 的 train_test_split 函数。我正在使用 pandas udf 来创建这个函数

这就是我所做的。

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def load_dataset(dataset):
    
    feature_columns = cols
    label = 'y';
    X = dataset[feature_columns]
    Y = dataset[label]
 
    # splitting the dataset into train and test
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)
    print(X_train)
 
    return X_train, X_test, y_train, y_test 

我想要这些数据帧 X_train、X_test、y_train、y_test 分开。

我知道udf函数是这样调用的

df.groupby("key").apply(load_dataset).show()

但我不知道用什么代替 key 另外,这会返回单个数据帧,我想要四个。

【问题讨论】:

我 100% 确定 pyspark 已经有这个功能了。 如果你只是想分割你的数据框,你可以使用randomSplit 但是我不想使用randomsplit,实际上我想在Pyspark中使用sklearn的train_test_split函数。 你能给我推荐那个链接吗? @约翰螺柱 据我所知,pandas_udf 无法做到这一点。您不能返回 4 个 Spark 数据帧。请阅读文档Pandas Function APIs。 【参考方案1】:

有什么问题:

df = inputDF.cache()
a,b = df.randomSplit([0.5, 0.5])

对于顺序很重要的时间序列,请使用:

df = df.withColumn("rank", percent_rank().over(Window.partitionBy().orderBy("departure_time")))

train_df = df.where("rank <= .8").drop("rank", "departure_time")

【讨论】:

我想使用 sklearn 的 train_test_split 对 X_test 和 y_test 数据帧进行二次采样。 您能提供更多见解吗?为什么不像我上面指定的那样生成 X_test 和 y_test,然后对“subsample”再次执行相同的过程? 我认为 randomsplit 不适合二次采样,这就是为什么我想为 sklearn 的 train_test_split 创建一个 pandas udf 以便我可以使用它。【参考方案2】:

实际上我必须进行二次采样。这就是为什么我必须从 train_test_split 函数返回四个变量。但是我连接了 X_test 和 y_test 并返回了一个数据帧。

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def load_dataset(dataset):
    
    feature_columns = cols
    label = 'y';
    X = dataset[feature_columns]
    Y = dataset[label]
 
    # splitting the dataset into train and test
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)
    print(X_train)
    
    df_sample = pd.concat([X_test, y_test], axis=1)
 
    return df_sample

这段代码对我有用。

【讨论】:

以上是关于如何在 Pyspark 中使用 @pandas_udf 返回多个数据帧?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Pyspark 中使用 Scala 函数?

如何在 Databricks 的 PySpark 中使用在 Scala 中创建的 DataFrame

如何在 Pyspark 中使用 groupby 和数组元素?

如何在 PySpark 中使用窗口函数?

如何在 pyspark 中使用 MultiClassMetrics 计算 f 分数?

如何在 pyspark 中使用“不存在”的 SQL 条件?