Spark/Scala 1.6 如何使用 dataframe groupby agg 来实现以下逻辑?

Posted

技术标签:

【中文标题】Spark/Scala 1.6 如何使用 dataframe groupby agg 来实现以下逻辑?【英文标题】:Spark/Scala 1.6 how to use dataframe groupby agg to implement following logical? 【发布时间】:2017-04-11 11:38:11 【问题描述】:

如何使用dataframe groupby agg实现以下逻辑?

ID    ID2   C1   C2   C3   C4   C5   C6 .....C33
CM1    a    1    1    1    0    0    0
CM2    a    1    1    0    1    0    0
CM3    a    1    0    1    1    1    0
CM4    a    1    1    1    1    1    0
CM5    a    1    1    1    1    1    0
1k2    b    0    0    1    1    1    0
1K3    b    1    1    1    1    1    0
1K1    b    0    0    0    0    1    0

我希望我的输出 df 看起来像这样

ID    ID2   C1   C2   C3   C4   C5   C6 .....C33
CM1    a    1    1    1    0    0    0
CM2    a    0    0    0    1    0    0
CM3    a    0    0    0    0    1    0
CM4    a    0    0    0    0    0    0
CM5    a    0    0    0    0    0    0
1K1    b    0    0    0    0    1    0
1k2    b    0    0    1    1    0    0
1K3    b    1    1    0    0    0    0

逻辑是根据ID2做group by,然后找到Cn为1时的最小ID然后置1,其他置0。

Cn 最高可达 C33。

如果用例类将超过限制。

我尝试过使用 mapPartitions

但是结果错了……

使用 Spark 1.6.0


添加我尝试过的代码

case class testGoods(ID: String, ID2: String, C1 : String, C2 : String)

val cartMap = new HashMap[String, Set[(String,String,String)]] with MultiMap[String,(String,String,String)]

val baseDF=hiveContext.sql(newSql)

val testRDD=baseDF.mapPartitions( partition => 
  while (partition.hasNext) 
    val record = partition.next()
    val ID = record.getString(0)
    if (ID != null && ID != "null") 
      val ID2=record.getString(1)
      val C1=record.getString(2)
      val C2=record.getString(3)
      cartMap.addBinding(ID2, (ID,C1,C2))
    
  
  cartMap.iterator
)

val recordList = new mutable.ListBuffer[testGoods]()
val testRDD1=testRDD.mapPartitions( partition => 
  while (partition.hasNext) 
    val record = partition.next()
    val ID2=record._1
    val recordRow= record._2
    val sortedRecordRow = TreeSet[(String,String,String)]() ++ recordRow
    val dic=new mutable.HashMap[String,String]


    for(v<-sortedRecordRow) 
      val ID = v._1
      val C1 = v._2
      val C2 = v._3

      if (dic.contains(ID2))
        val goodsValue=dic.get(ID2)
        if("1".equals(goodsValue))
          recordList.append(new testGoods(ID, ID2, "0", C2))
        else
          dic.put(ID2,C1)
          recordList.append(new testGoods(ID, ID2, C1,C2))
        
      else
        dic.put(ID2,C1)
        recordList.append(new testGoods(ID, ID2, C1, C2))
      
    
  
  recordList.iterator
)

再次编辑

原始数据集有数百万个ID,按ID分组后,每个ID2可能有2~300个数据。

【问题讨论】:

你能添加一些你试过的代码吗? 当然,我添加了一些代码并编辑了示例。 【参考方案1】:

基本上,如果您的子表很小(正如您在评论中提到的 2~300 个数据点),您所需要的只是这样的:

val columnIds = List(2, 3, 4, 5, 6, 7)// (preColumns to preColumns + numColumns).toList
val columnNames = List("C1", "C2", "C3", "C4", "C5", "C6") //just for representability

val withKey = df.rdd.map(c => c.getString(1) -> c).groupByKey

val res = withKey.flatMap 
  case (id, local) =>
    case class Accumulator(found: Set[Int] = Set.empty, result: List[Row] = List.empty)    
    local.foldLeft(Accumulator())
      case (acc, row) => 
        val found = columnIds.filter(id => row.getInt(id) != 0) //columns with `1`
        val pre = Seq(row(0), row(1))
        val res = pre ++ columnIds.map cid => 
          if (acc.found.contains(cid)) 0 else row.get(cid)
        
        Accumulator(acc.found ++ found, Row.fromSeq(res) :: acc.result)
    .result.reverse

您可以通过使用比List 更合适的集合来删除reverse 作为累加器(例如Queue) 引入Accumulator 是为了在表示我们已经找到“1”的列时避免可变变量。但是,您可以在 Spark 中使用 var 的内部 lambda(但不能在外部!) - 它也可以工作。 Accumulator 将处理后的列保存在 found 中,并将结果本身保存在 result 中。 使用 groupByKey 是因为您无法在 Spark 中的 rdd 中访问 rdd,您只能在流中访问。它也是代码中 mapPartitions 的“更智能”替代品 - 它可以理解“按键分组”而不是分区。 local 是数据框的本地子集,您基本上可以将其作为常规的Scala-collection 操作。

RDD 转换回DataFrame

val newDf = sqlContext.createDataFrame(res, df.schema)

实验:

data.csv:
ID,ID2,C1,C2,C3,C4,C5,C6
CM1,a,1,1,1,0,0,0
CM2,a,1,1,0,1,0,0
CM3,a,1,0,1,1,1,0
1k2,b,0,0,1,1,1,0
1K3,b,1,1,1,1,1,0
1K1,b,0,0,0,0,1,0

val sqlContext = new SQLContext(sc)

val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true").load("data.csv")
...
res.collect()

res77_4: Array[Row] = Array(
[CM1,a,1,1,1,0,0,0],
  [CM2,a,0,0,0,1,0,0],
  [CM3,a,0,0,0,0,1,0],
  [1k2,b,0,0,1,1,1,0],
  [1K3,b,1,1,0,0,0,0],
  [1K1,b,0,0,0,0,0,0]
)

对于处理稀疏数据(有很多零)的其他情况 - 考虑使用 Mllib 的矩阵:https://spark.apache.org/docs/2.1.0/mllib-data-types.html#distributed-matrix。它们可以更有效地保存和处理稀疏结构。

我还建议使用 case-class 表示 Row 以避免使用列索引。

【讨论】:

太棒了!谢谢! 另一个问题,如果我的子表很大,我将如何更改代码?我发现我的一个子表有 30000 个数据点,可能会抛出错误:Listener SQLListener throw an exception:java.lang.NullPointerException,但我不确定。 原始数据集大约20G,有250列。所以我改变了一些代码。 val columnIds =(213 to 245).toList val withKey = baseDF.rdd.map(c => c.getString(129) -> c).groupByKey val pre = (0 to 212).map(i=>row( i)) @user3127157 30000 并不算多,IMO。无论如何,如果您确实面临可伸缩性问题,有几个选项,但只要您的数据集按 id2 排序,您就可以引入一些子键并存储每个子键(甚至子键中的每个记录)的累加器标志(“找到”)并在第二阶段合并子项。【参考方案2】:

这是我解决此问题的算法。绝对有很大的改进空间。 基本上我建议分两个阶段解决。

    为数据集中的每个第一个 1 构建 Id 映射。

    根据此地图绘制完整数据框 - 将所有不满意的记录替换为 0

val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true").load("data.csv")

val Limit = 5 // or 33

val re = (1 to Limit).map  i =>
    val col = "C" + i
    val first = df.filter(df(col) > 0).head
    val id  = first(0) // id column name
    i -> id
.toMap
// result of this stage is smth like:Map(5 -> CM3, 1 -> CM1, 2 -> CM1, 3 -> CM1, 4 -> CM2)    

df.map(columns => 
     Seq(columns(0), columns(1)) ++ 
        (1 to Limit).map  x => 
             if (columns(0) == re(x)) 1 else 0 
      ).foreach(println)

【讨论】:

感谢您的回答,我为 0 种情况编辑了示例,导致 java.util.NoSuchElementException: next on empty iterator 您似乎忽略了ID2-grouping。据我了解,目标是应用此逻辑ID2-wise @dk14 在原始查询中没有关于 id2 分组的内容。但如果用户理解概念,就很容易应用另一个分组层。 @user3127157 - 很可能是因为您的列中有零而没有任何 1。尝试调试建议的解决方案。检查re的输出是什么 在不放弃可扩展性的情况下,在 spark-sql(甚至 rdd)中应用分组层并不容易。所以@user3127157 的问题是表中有多少个“a”

以上是关于Spark/Scala 1.6 如何使用 dataframe groupby agg 来实现以下逻辑?的主要内容,如果未能解决你的问题,请参考以下文章

如何在Ubuntu下搭建Spark集群

用scala在spark中读取压缩文件

如何使用 spark/scala 解析 YAML

如何在 Spark/Scala 中使用 countDistinct?

spark Scala 数据框选择

在 Spark Scala 中解码 Base64