pyspark - 左连接,随机行匹配键
Posted
技术标签:
【中文标题】pyspark - 左连接,随机行匹配键【英文标题】:pyspark - left join with random row matching the key 【发布时间】:2021-10-01 15:56:07 【问题描述】:我正在寻找一种方法来加入 2 个数据帧,但随机行与键匹配。这个奇怪的请求是由于生成位置的计算非常长。 我想在 pyspark 中做一种“随机左连接”。
我有一个带有areaID
(字符串)和count
(整数)的数据框。 areaID
是独一无二的(大约 7k)。
+--------+-------+
| areaID | count |
+--------+-------+
| A | 10 |
| B | 30 |
| C | 1 |
| D | 25 |
| E | 18 |
+--------+-------+
我有第二个数据框,每个 areaID
有大约 1000 个预计算行,有 2 个位置列 x
(浮点数)和 y
(浮点数)。这个数据框大约有 700 万行。
+--------+------+------+
| areaID | x | y |
+--------+------+------+
| A | 0.0 | 0 |
| A | 0.1 | 0.7 |
| A | 0.3 | 1 |
| A | 0.1 | 0.3 |
| ... | | |
| E | 3.15 | 4.17 |
| E | 3.14 | 4.22 |
+--------+------+------+
我想以如下数据框结束:
+--------+------+------+
| areaID | x | y |
+--------+------+------+
| A | 0.1 | 0.32 | < row 1/10 - randomly picked where areaID are the same
| A | 0.0 | 0.18 | < row 2/10
| A | 0.09 | 0.22 | < row 3/10
| ... | | |
| E | 3.14 | 4.22 | < row 1/18
| ... | | |
+--------+------+------+
我的第一个想法是遍历第一个数据帧的每个areaID
,通过areaID
过滤第二个数据帧,并对该数据帧的count
行进行采样。问题是这在 7k 加载/过滤/采样过程中非常缓慢。
第二种方法是在areaID
上进行外部 连接,然后打乱数据帧(但看起来很复杂),应用等级并在rank
count 时保持但我不喜欢加载大量数据以供事后过滤的方法。
我想知道是否有办法使用“随机”left join 来做到这一点?在这种情况下,我会将每一行复制count
次并应用它。
非常感谢,
尼古拉斯
【问题讨论】:
【参考方案1】:可以将问题解释为第二个数据帧的stratified sampling,其中从每个子群体中抽取的样本数由第一个数据帧给出。
stratified sampling 有 Spark 功能。
df1 = ...
df2 = ...
#first calculate the fraction for each areaID based on the required number
#given in df1 and the number of rows for the areaID in df2
fractionRows = df2.groupBy("areaId").agg(F.count("areaId").alias("count2")) \
.join(df1, "areaId") \
.withColumn("fraction", F.col("count") / F.col("count2")) \
.select("areaId", "fraction") \
.collect()
fractions = f[0]:f[1] for f in fractionRows
#now run the statified samling
df2.stat.sampleBy("areaID", fractions).show()
这种方法有一点需要注意:由于 Spark 进行的采样是一个随机过程,因此第一个数据帧中给出的确切行数并不总是完全符合。
编辑:sampleBy
不支持大于 1.0 的分数。查看sampleBy
的Scala code 说明了原因:该函数被实现为带有一个随机变量的过滤器,该变量指示是否保持行。因此,返回单行的多个副本将不起作用。
类似的想法可用于支持大于 1.0 的分数:不使用过滤器,而是创建一个返回数组的 udf。该数组包含一个应包含在结果中的行的每个副本的条目。应用 udf 后,数组列被炸掉,然后被丢弃:
from pyspark.sql import functions as F
from pyspark.sql import types as T
fractions = 'A': 1.5, 'C': 0.5
def ff(stratum,x):
fraction = fractions.get(stratum, 0.0)
ret=[]
while fraction >= 1.0:
ret.append("x")
fraction = fraction - 1
if x < fraction:
ret.append("x")
return ret
f=F.udf(ff, T.ArrayType(T.StringType())).asNondeterministic()
seed=42
df2.withColumn("r", F.rand(seed)) \
.withColumn("r",f("areaID", F.col("r")))\
.withColumn("r", F.explode("r")) \
.drop("r") \
.show()
【讨论】:
似乎正是我所需要的,我会在当前计算完成后试一试(并接受答案)。我不知道 DataFrameStatFunctions 模块......似乎很方便:D 效果很好。我还有另一个子问题:)。如果我想更换,是否有解决方法?例如,如果我有 2.4 的一小部分? 我认为分数 > 1.0 有一个解决方法。我已将其添加到我的答案中 哇,太棒了:)。非常感谢你以上是关于pyspark - 左连接,随机行匹配键的主要内容,如果未能解决你的问题,请参考以下文章