将一个数据帧的数组列与scala中另一个数据帧的数组列的子集进行比较
Posted
技术标签:
【中文标题】将一个数据帧的数组列与scala中另一个数据帧的数组列的子集进行比较【英文标题】:compare array column of one dataframe to subsets of array column of another dataframe in scala 【发布时间】:2018-11-09 11:30:41 【问题描述】:我有两个如下的数据框:
df1 = (Receipt_no: String , Items_no_set:Array[String])
+-----------+-------------------+
| Receipt_no| Items_no_set |
+-----------+-------------------+
| 001| [123,124,125] |
| 002| [501,502,503,504] |
| 003| [123,501,125] |
+-----------+-------------------+
df2 = (product_no: String , product_items_set:Array[String])
+-----------+-------------------+
| product_no| product_items_set |
+-----------+-------------------+
| 909| [123,124] |
| 908| [501,502,503] |
| 907| [123,501,125] |
+-----------+-------------------+
现在我想将 df1(Items_no_set) 与 df2(product_items_set) 进行比较,如果找到的匹配项返回 df3(Receipt_no,Items_no_set,product_no)。
如果在上述情况下找不到匹配项,我想创建 df1(Items_no_set) 的子集,然后比较是否找到匹配项
我的预期输出:
+-----------+-------------------+-----------+
| Receipt_no| Items_no_set | product_no|
+-----------+-------------------+-----------+
| 001| [123,124] | 909 |
| 002| [501,502,503,504] | 908 |
| 003| [123,501,125] | 907 |
+-----------+-------------------+-----------+
我正在努力实现上述步骤和我的预期输出。任何帮助将不胜感激。
【问题讨论】:
【参考方案1】:由于 df1 和 df2 之间没有匹配的 key,我们必须做 crossJoin。看看这个 rdd 解决方案:
scala> val df1 = Seq(
| ("001",Array(123,124,125)),
| ("002",Array(501,502,503,504)),
| ("003",Array(123,501,125)) ).toDF("receipt_no","items_no_set")
df1: org.apache.spark.sql.DataFrame = [receipt_no: string, items_no_set: array<int>]
scala> val df2 = Seq(
| ("909",Array(123,124)),
| ("908",Array(501,502,503)),
| ("907",Array(123,501,125)) ).toDF("product_no","product_items_set")
df2: org.apache.spark.sql.DataFrame = [product_no: string, product_items_set: array<int>]
scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._
scala> val df3 = df1.crossJoin(df2)
df3: org.apache.spark.sql.DataFrame = [receipt_no: string, items_no_set: array<int> ... 2 more fields]
scala> val rdd2= df3.rdd.filter( x =>
| val items = x.getAs[scala.collection.mutable.WrappedArray[Int]]("items_no_set").toArray;
| val prds = x.getAs[scala.collection.mutable.WrappedArray[Int]]("product_items_set").toArray;
| val chk = prds.intersect(items).length == prds.length
| ( chk == true )
|
| ).map( x => Row(x(0),x(1),x(2)))
rdd2: org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] = MapPartitionsRDD[76] at map at <console>:50
scala> val schema = df1.schema.add(StructField("product_no",StringType))
schema: org.apache.spark.sql.types.StructType = StructType(StructField(receipt_no,StringType,true), StructField(items_no_set,ArrayType(IntegerType,false),true), StructField(product_no,StringType,true))
scala> spark.createDataFrame(rdd2,schema).show(false)
+----------+--------------------+----------+
|receipt_no|items_no_set |product_no|
+----------+--------------------+----------+
|001 |[123, 124, 125] |909 |
|002 |[501, 502, 503, 504]|908 |
|003 |[123, 501, 125] |907 |
+----------+--------------------+----------+
scala>
【讨论】:
以上是关于将一个数据帧的数组列与scala中另一个数据帧的数组列的子集进行比较的主要内容,如果未能解决你的问题,请参考以下文章
将 MultiIndex Pandas 数据帧乘以来自另一个数据帧的多个标量
如何根据一个数据帧中的列值和R中另一个数据帧的列标题名称有条件地创建新列
在 Apache Spark (Scala) 上获取两个数据帧的差异