在新列上过滤 Spark DataFrame

Posted

技术标签:

【中文标题】在新列上过滤 Spark DataFrame【英文标题】:Filtering Spark DataFrame on new column 【发布时间】:2017-01-30 21:43:13 【问题描述】:

上下文:我的数据集太大而无法放入内存中,我正在训练 Keras RNN。我在 AWS EMR 集群上使用 PySpark 批量训练模型,这些模型小到可以存储在内存中。我无法使用elephas 实现分布式模型,我怀疑这与我的模型是有状态的有关。不过我不完全确定。

每个用户的数据框都有一行,从安装之日起经过的天数从 0 到 29。查询数据库后,我对数据框进行了一些操作:

query = """WITH max_days_elapsed AS (
        SELECT user_id,
            max(days_elapsed) as max_de
        FROM table
        GROUP BY user_id
        )
        SELECT table.*
        FROM table
            LEFT OUTER JOIN max_days_elapsed USING (user_id)
        WHERE max_de = 1
            AND days_elapsed < 1"""

df = read_from_db(query) #this is just a custom function to query our database

#Create features vector column
assembler = VectorAssembler(inputCols=features_list, outputCol="features")
df_vectorized = assembler.transform(df)

#Split users into train and test and assign batch number
udf_randint = udf(lambda x: np.random.randint(0, x), IntegerType())
training_users, testing_users = df_vectorized.select("user_id").distinct().randomSplit([0.8,0.2],123)
training_users = training_users.withColumn("batch_number", udf_randint(lit(N_BATCHES)))

#Create and sort train and test dataframes
train = df_vectorized.join(training_users, ["user_id"], "inner").select(["user_id", "days_elapsed","batch_number","features", "kpi1", "kpi2", "kpi3"])
train = train.sort(["user_id", "days_elapsed"])
test = df_vectorized.join(testing_users, ["user_id"], "inner").select(["user_id","days_elapsed","features", "kpi1", "kpi2", "kpi3"])
test = test.sort(["user_id", "days_elapsed"])

我遇到的问题是,如果没有缓存火车,我似乎无法过滤 batch_number。我可以过滤我们数据库中原始数据集中的任何列,但不能过滤我在查询数据库后在 pyspark 中生成的任何列:

这个:train.filter(train["days_elapsed"] == 0).select("days_elapsed").distinct.show() 只返回 0。

但是,所有这些都返回 0 到 9 之间的所有批号,没有任何过滤:

train.filter(train["batch_number"] == 0).select("batch_number").distinct().show() train.filter(train.batch_number == 0).select("batch_number").distinct().show() train.filter("batch_number = 0").select("batch_number").distinct().show() train.filter(col("batch_number") == 0).select("batch_number").distinct().show()

这也不起作用:

train.createOrReplaceTempView("train_table")
batch_df = spark.sql("SELECT * FROM train_table WHERE batch_number = 1")
batch_df.select("batch_number").distinct().show()

如果我先执行 train.cache() 所有这些工作。这是绝对必要的还是有办法在不缓存的情况下做到这一点?

【问题讨论】:

【参考方案1】:

Spark >= 2.3(? - 取决于 SPARK-22629 的进度)

应该可以使用asNondeterministic 方法禁用某些优化。

火花

不要使用 UDF 生成随机数。首先引用the docs:

用户定义的函数必须是确定性的。由于优化,可能会消除重复调用,或者该函数的调用次数甚至可能比查询中出现的次数多。

即使不是 UDF,也有 Spark 的微妙之处,这使得在处理单个记录时几乎不可能实现这一权利。

Spark 已经提供rand:

使用来自 U[0.0, 1.0] 的独立且同分布 (i.i.d.) 样本生成一个随机列。

randn

使用来自标准正态分布的独立且同分布 (i.i.d.) 样本生成一列。

可用于构建更复杂的生成器函数。

注意

您的代码可能存在其他一些问题,但这从一开始就让人无法接受(Random numbers generation in PySpark、pyspark. Transformer that generates a random number generates always the same number)。

【讨论】:

以上是关于在新列上过滤 Spark DataFrame的主要内容,如果未能解决你的问题,请参考以下文章

具有多列的 XSLT 区域主体,强制块在新列上开始

在 Pandas 数据框中找到最小值并在新列上添加标签

在另一列上查找最近的时间戳并在新列中添加值 PySpark

Spark - 在数据集的几列上应用 UDF 并形成新列

spark:模式更改——如果存在,则转换和过滤列上的数据框;如果没有就不要

在镶木地板的地图类型列上使用 spark-sql 过滤下推