将一个数据帧的数组列与scala中另一个数据帧的数组列的子集进行比较

Posted

技术标签:

【中文标题】将一个数据帧的数组列与scala中另一个数据帧的数组列的子集进行比较【英文标题】:compare array column of one dataframe to subsets of array column of another dataframe in scala 【发布时间】:2018-11-09 11:30:41 【问题描述】:

我有两个如下的数据框:

df1 = (Receipt_no: String , Items_no_set:Array[String])

+-----------+-------------------+
| Receipt_no| Items_no_set      |
+-----------+-------------------+
|        001| [123,124,125]     |
|        002| [501,502,503,504] |
|        003| [123,501,125]     |
+-----------+-------------------+


df2 = (product_no: String , product_items_set:Array[String])

+-----------+-------------------+
| product_no| product_items_set |
+-----------+-------------------+
|        909| [123,124]         |
|        908| [501,502,503]     |
|        907| [123,501,125]     |
+-----------+-------------------+

现在我想将 df1(Items_no_set) 与 df2(product_items_set) 进行比较,如果找到的匹配项返回 df3(Receipt_no,Items_no_set,product_no)。

如果在上述情况下找不到匹配项,我想创建 df1(Items_no_set) 的子集,然后比较是否找到匹配项

我的预期输出:

+-----------+-------------------+-----------+
| Receipt_no| Items_no_set      | product_no|
+-----------+-------------------+-----------+
|        001| [123,124]         |   909     |
|        002| [501,502,503,504] |   908     |
|        003| [123,501,125]     |   907     |
+-----------+-------------------+-----------+

我正在努力实现上述步骤和我的预期输出。任何帮助将不胜感激。

【问题讨论】:

【参考方案1】:

由于 df1 和 df2 之间没有匹配的 key,我们必须做 crossJoin。看看这个 rdd 解决方案:

scala> val df1 = Seq(
     |       ("001",Array(123,124,125)),
     |       ("002",Array(501,502,503,504)),
     |       ("003",Array(123,501,125)) ).toDF("receipt_no","items_no_set")
df1: org.apache.spark.sql.DataFrame = [receipt_no: string, items_no_set: array<int>]

scala> val df2 = Seq(
     |       ("909",Array(123,124)),
     |       ("908",Array(501,502,503)),
     |       ("907",Array(123,501,125)) ).toDF("product_no","product_items_set")
df2: org.apache.spark.sql.DataFrame = [product_no: string, product_items_set: array<int>]

scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._

scala> val df3 = df1.crossJoin(df2)
df3: org.apache.spark.sql.DataFrame = [receipt_no: string, items_no_set: array<int> ... 2 more fields]

scala> val rdd2= df3.rdd.filter( x => 
     |             val items = x.getAs[scala.collection.mutable.WrappedArray[Int]]("items_no_set").toArray;
     |             val prds = x.getAs[scala.collection.mutable.WrappedArray[Int]]("product_items_set").toArray;
     |             val chk = prds.intersect(items).length == prds.length
     |             ( chk == true )
     |             
     |             ).map( x => Row(x(0),x(1),x(2)))
rdd2: org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] = MapPartitionsRDD[76] at map at <console>:50

scala> val schema = df1.schema.add(StructField("product_no",StringType))
schema: org.apache.spark.sql.types.StructType = StructType(StructField(receipt_no,StringType,true), StructField(items_no_set,ArrayType(IntegerType,false),true), StructField(product_no,StringType,true))

scala> spark.createDataFrame(rdd2,schema).show(false)
+----------+--------------------+----------+
|receipt_no|items_no_set        |product_no|
+----------+--------------------+----------+
|001       |[123, 124, 125]     |909       |
|002       |[501, 502, 503, 504]|908       |
|003       |[123, 501, 125]     |907       |
+----------+--------------------+----------+


scala>

【讨论】:

以上是关于将一个数据帧的数组列与scala中另一个数据帧的数组列的子集进行比较的主要内容,如果未能解决你的问题,请参考以下文章

将 MultiIndex Pandas 数据帧乘以来自另一个数据帧的多个标量

如何根据一个数据帧中的列值和R中另一个数据帧的列标题名称有条件地创建新列

在 Apache Spark (Scala) 上获取两个数据帧的差异

scala 前向参考扩展了值数据帧的定义

根据来自其他数据帧的位置条件在数据帧上编写选择查询,scala

获取一行数据帧的字段值 - Spark Scala