pyspark 检查每个名字是不是有3个数据

Posted

技术标签:

【中文标题】pyspark 检查每个名字是不是有3个数据【英文标题】:pyspark check whether each name has 3 datapyspark 检查每个名字是否有3个数据 【发布时间】:2021-07-19 04:36:51 【问题描述】:

在 pyspark 中,我有一个 DataFrame,如下所示。我想检查每个名称是否有 3 个动作数据(0、1、2)。如果有缺失,则添加新行,分数列设置为0,其他列不变(例如:str1、str2、str3)。

+-----+--------+--------+--------+-------+-------+
| name|  str1  |  str2  |  str3  | action| score |
+-----+--------+--------+--------+-------+-------+
|  A  | str_A1 | str_A2 | str_A3 |      0|      2|
|  A  | str_A1 | str_A2 | str_A3 |      1|      6|
|  A  | str_A1 | str_A2 | str_A3 |      2|     74|
|  B  | str_B1 | str_B2 | str_B3 |      0|     59|
|  B  | str_B1 | str_B2 | str_B3 |      1|     18|
|  C  | str_C1 | str_C2 | str_C3 |      0|      3|
|  C  | str_C1 | str_C2 | str_C3 |      1|     33|
|  C  | str_C1 | str_C2 | str_C3 |      2|      3|
+-----+--------+--------+--------+-------+-------+

比如名字B没有动作2,添加新的一行数据如下

+-----+--------+--------+--------+-------+-------+
| name|  str1  |  str2  |  str3  | action| score |
+-----+--------+--------+--------+-------+-------+
|  A  | str_A1 | str_A2 | str_A3 |      0|      2|
|  A  | str_A1 | str_A2 | str_A3 |      1|      6|
|  A  | str_A1 | str_A2 | str_A3 |      2|     74|
|  B  | str_B1 | str_B2 | str_B3 |      0|     59|
|  B  | str_B1 | str_B2 | str_B3 |      1|     18|
|  B  | str_B1 | str_B2 | str_B3 |      2|      0|<---- new row data
|  C  | str_C1 | str_C2 | str_C3 |      0|      3|
|  C  | str_C1 | str_C2 | str_C3 |      1|     33|
|  C  | str_C1 | str_C2 | str_C3 |      2|      3|
+-----+--------+--------+--------+-------+-------+

也有可能一个名字只有一个行数据,需要新增两个行数据。

+-----+--------+--------+--------+-------+-------+
| name|  str1  |  str2  |  str3  | action| score |
+-----+--------+--------+--------+-------+-------+
|  A  | str_A1 | str_A2 | str_A3 |      0|      2|
|  A  | str_A1 | str_A2 | str_A3 |      1|      6|
|  A  | str_A1 | str_A2 | str_A3 |      2|     74|
|  B  | str_B1 | str_B2 | str_B3 |      0|     59|
|  B  | str_B1 | str_B2 | str_B3 |      1|     18|
|  B  | str_B1 | str_B2 | str_B3 |      2|      0| 
|  C  | str_C1 | str_C2 | str_C3 |      0|      3|
|  C  | str_C1 | str_C2 | str_C3 |      1|     33|
|  C  | str_C1 | str_C2 | str_C3 |      2|      3|
|  D  | str_D1 | str_D2 | str_D3 |      0|     45|
+-----+--------+--------+--------+-------+-------+

+-----+--------+--------+--------+-------+-------+
| name|  str1  |  str2  |  str3  | action| score |
+-----+--------+--------+--------+-------+-------+
|  A  | str_A1 | str_A2 | str_A3 |      0|      2|
|  A  | str_A1 | str_A2 | str_A3 |      1|      6|
|  A  | str_A1 | str_A2 | str_A3 |      2|     74|
|  B  | str_B1 | str_B2 | str_B3 |      0|     59|
|  B  | str_B1 | str_B2 | str_B3 |      1|     18|
|  B  | str_B1 | str_B2 | str_B3 |      2|      0| 
|  C  | str_C1 | str_C2 | str_C3 |      0|      3|
|  C  | str_C1 | str_C2 | str_C3 |      1|     33|
|  C  | str_C1 | str_C2 | str_C3 |      2|      3|
|  D  | str_D1 | str_D2 | str_D3 |      0|     45|
|  D  | str_D1 | str_D2 | str_D3 |      1|      0|<---- new row data
|  D  | str_D1 | str_D2 | str_D3 |      2|      0|<---- new row data
+-----+--------+--------+--------+-------+-------+

我是 pyspark 的新手,不知道如何执行此操作。 感谢您的帮助。

【问题讨论】:

【参考方案1】:

使用 UDF 的解决方案

from pyspark.sql import functions as F, types as T

@F.udf(T.MapType(T.StringType(), T.IntegerType()))
def add_missing_values(values):
    return i: values.get(i, 0) for i in range(3)

df = (
    df.groupBy("name", "str1", "str2", "str3")
    .agg(
        F.map_from_entries(F.collect_list(F.struct("action", "score"))).alias("values")
    )
    .withColumn("values", add_missing_values(F.col("values")))
    .select(
        "name", "str1", "str2", "str3", F.explode("values").alias("action", "score")
    )
)

df.show()

+----+------+------+------+------+-----+                                        
|name|  str1|  str2|  str3|action|score|
+----+------+------+------+------+-----+
|   A|str_A1|str_A2|str_A3|     0|    2|
|   A|str_A1|str_A2|str_A3|     1|    6|
|   A|str_A1|str_A2|str_A3|     2|   74|
|   B|str_B1|str_B2|str_B3|     0|   59|
|   B|str_B1|str_B2|str_B3|     1|   18|
|   B|str_B1|str_B2|str_B3|     2|    0|<---- new row data
|   C|str_C1|str_C2|str_C3|     0|    3|
|   C|str_C1|str_C2|str_C3|     1|   33|
|   C|str_C1|str_C2|str_C3|     2|    3|
|   D|str_D1|str_D2|str_D3|     0|   45|
|   D|str_D1|str_D2|str_D3|     1|    0|<---- new row data
|   D|str_D1|str_D2|str_D3|     2|    0|<---- new row data
+----+------+------+------+------+-----+

完整的 Spark 解决方案:

df = (
    df.groupBy("name", "str1", "str2", "str3")
    .agg(
        F.map_from_entries(F.collect_list(F.struct("action", "score"))).alias("values")
    )
    .withColumn(
        "values",
        F.map_from_arrays(
            F.array([F.lit(i) for i in range(3)]),
            F.array(
                [F.coalesce(F.col("values").getItem(i), F.lit(0)) for i in range(3)]
            ),
        ),
    )
    .select(
        "name", "str1", "str2", "str3", F.explode("values").alias("action", "score")
    )
)

【讨论】:

以上是关于pyspark 检查每个名字是不是有3个数据的主要内容,如果未能解决你的问题,请参考以下文章

检查列 pyspark df 的值是不是存在于其他列 pyspark df

Pyspark:检查数据框中是不是存在列[重复]

Pyspark:通过检查值是不是存在来聚合数据(不是计数或总和)

检查一列是不是与pyspark中的groupby连续

如何检查 Dataproc 上 pyspark 作业的每个执行程序/节点内存使用指标?

如何检查 Pyspark Dataframe 中是不是存在列表交集