Spark/Scala 1.6 如何使用 dataframe groupby agg 来实现以下逻辑?
Posted
技术标签:
【中文标题】Spark/Scala 1.6 如何使用 dataframe groupby agg 来实现以下逻辑?【英文标题】:Spark/Scala 1.6 how to use dataframe groupby agg to implement following logical? 【发布时间】:2017-04-11 11:38:11 【问题描述】:如何使用dataframe groupby agg实现以下逻辑?
ID ID2 C1 C2 C3 C4 C5 C6 .....C33
CM1 a 1 1 1 0 0 0
CM2 a 1 1 0 1 0 0
CM3 a 1 0 1 1 1 0
CM4 a 1 1 1 1 1 0
CM5 a 1 1 1 1 1 0
1k2 b 0 0 1 1 1 0
1K3 b 1 1 1 1 1 0
1K1 b 0 0 0 0 1 0
我希望我的输出 df 看起来像这样
ID ID2 C1 C2 C3 C4 C5 C6 .....C33
CM1 a 1 1 1 0 0 0
CM2 a 0 0 0 1 0 0
CM3 a 0 0 0 0 1 0
CM4 a 0 0 0 0 0 0
CM5 a 0 0 0 0 0 0
1K1 b 0 0 0 0 1 0
1k2 b 0 0 1 1 0 0
1K3 b 1 1 0 0 0 0
逻辑是根据ID2做group by,然后找到Cn为1时的最小ID然后置1,其他置0。
Cn 最高可达 C33。
如果用例类将超过限制。
我尝试过使用 mapPartitions
但是结果错了……
使用 Spark 1.6.0
添加我尝试过的代码
case class testGoods(ID: String, ID2: String, C1 : String, C2 : String)
val cartMap = new HashMap[String, Set[(String,String,String)]] with MultiMap[String,(String,String,String)]
val baseDF=hiveContext.sql(newSql)
val testRDD=baseDF.mapPartitions( partition =>
while (partition.hasNext)
val record = partition.next()
val ID = record.getString(0)
if (ID != null && ID != "null")
val ID2=record.getString(1)
val C1=record.getString(2)
val C2=record.getString(3)
cartMap.addBinding(ID2, (ID,C1,C2))
cartMap.iterator
)
val recordList = new mutable.ListBuffer[testGoods]()
val testRDD1=testRDD.mapPartitions( partition =>
while (partition.hasNext)
val record = partition.next()
val ID2=record._1
val recordRow= record._2
val sortedRecordRow = TreeSet[(String,String,String)]() ++ recordRow
val dic=new mutable.HashMap[String,String]
for(v<-sortedRecordRow)
val ID = v._1
val C1 = v._2
val C2 = v._3
if (dic.contains(ID2))
val goodsValue=dic.get(ID2)
if("1".equals(goodsValue))
recordList.append(new testGoods(ID, ID2, "0", C2))
else
dic.put(ID2,C1)
recordList.append(new testGoods(ID, ID2, C1,C2))
else
dic.put(ID2,C1)
recordList.append(new testGoods(ID, ID2, C1, C2))
recordList.iterator
)
再次编辑
原始数据集有数百万个ID,按ID分组后,每个ID2可能有2~300个数据。
【问题讨论】:
你能添加一些你试过的代码吗? 当然,我添加了一些代码并编辑了示例。 【参考方案1】:基本上,如果您的子表很小(正如您在评论中提到的 2~300 个数据点),您所需要的只是这样的:
val columnIds = List(2, 3, 4, 5, 6, 7)// (preColumns to preColumns + numColumns).toList
val columnNames = List("C1", "C2", "C3", "C4", "C5", "C6") //just for representability
val withKey = df.rdd.map(c => c.getString(1) -> c).groupByKey
val res = withKey.flatMap
case (id, local) =>
case class Accumulator(found: Set[Int] = Set.empty, result: List[Row] = List.empty)
local.foldLeft(Accumulator())
case (acc, row) =>
val found = columnIds.filter(id => row.getInt(id) != 0) //columns with `1`
val pre = Seq(row(0), row(1))
val res = pre ++ columnIds.map cid =>
if (acc.found.contains(cid)) 0 else row.get(cid)
Accumulator(acc.found ++ found, Row.fromSeq(res) :: acc.result)
.result.reverse
您可以通过使用比List
更合适的集合来删除reverse
作为累加器(例如Queue
)
引入Accumulator
是为了在表示我们已经找到“1”的列时避免可变变量。但是,您可以在 Spark 中使用 var
的内部 lambda(但不能在外部!) - 它也可以工作。
Accumulator
将处理后的列保存在 found
中,并将结果本身保存在 result
中。
使用 groupByKey
是因为您无法在 Spark 中的 rdd 中访问 rdd,您只能在流中访问。它也是代码中 mapPartitions
的“更智能”替代品 - 它可以理解“按键分组”而不是分区。
local
是数据框的本地子集,您基本上可以将其作为常规的Scala
-collection 操作。
将RDD
转换回DataFrame
:
val newDf = sqlContext.createDataFrame(res, df.schema)
实验:
data.csv:
ID,ID2,C1,C2,C3,C4,C5,C6
CM1,a,1,1,1,0,0,0
CM2,a,1,1,0,1,0,0
CM3,a,1,0,1,1,1,0
1k2,b,0,0,1,1,1,0
1K3,b,1,1,1,1,1,0
1K1,b,0,0,0,0,1,0
val sqlContext = new SQLContext(sc)
val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true").load("data.csv")
...
res.collect()
res77_4: Array[Row] = Array(
[CM1,a,1,1,1,0,0,0],
[CM2,a,0,0,0,1,0,0],
[CM3,a,0,0,0,0,1,0],
[1k2,b,0,0,1,1,1,0],
[1K3,b,1,1,0,0,0,0],
[1K1,b,0,0,0,0,0,0]
)
对于处理稀疏数据(有很多零)的其他情况 - 考虑使用 Mllib 的矩阵:https://spark.apache.org/docs/2.1.0/mllib-data-types.html#distributed-matrix。它们可以更有效地保存和处理稀疏结构。
我还建议使用 case-class 表示 Row
以避免使用列索引。
【讨论】:
太棒了!谢谢! 另一个问题,如果我的子表很大,我将如何更改代码?我发现我的一个子表有 30000 个数据点,可能会抛出错误:Listener SQLListener throw an exception:java.lang.NullPointerException,但我不确定。 原始数据集大约20G,有250列。所以我改变了一些代码。 val columnIds =(213 to 245).toList val withKey = baseDF.rdd.map(c => c.getString(129) -> c).groupByKey val pre = (0 to 212).map(i=>row( i)) @user3127157 30000 并不算多,IMO。无论如何,如果您确实面临可伸缩性问题,有几个选项,但只要您的数据集按 id2 排序,您就可以引入一些子键并存储每个子键(甚至子键中的每个记录)的累加器标志(“找到”)并在第二阶段合并子项。【参考方案2】:这是我解决此问题的算法。绝对有很大的改进空间。 基本上我建议分两个阶段解决。
为数据集中的每个第一个 1
构建 Id 映射。
根据此地图绘制完整数据框 - 将所有不满意的记录替换为 0
val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true").load("data.csv")
val Limit = 5 // or 33
val re = (1 to Limit).map i =>
val col = "C" + i
val first = df.filter(df(col) > 0).head
val id = first(0) // id column name
i -> id
.toMap
// result of this stage is smth like:Map(5 -> CM3, 1 -> CM1, 2 -> CM1, 3 -> CM1, 4 -> CM2)
df.map(columns =>
Seq(columns(0), columns(1)) ++
(1 to Limit).map x =>
if (columns(0) == re(x)) 1 else 0
).foreach(println)
【讨论】:
感谢您的回答,我为 0 种情况编辑了示例,导致 java.util.NoSuchElementException: next on empty iterator 您似乎忽略了ID2
-grouping。据我了解,目标是应用此逻辑ID2
-wise
@dk14 在原始查询中没有关于 id2 分组的内容。但如果用户理解概念,就很容易应用另一个分组层。
@user3127157 - 很可能是因为您的列中有零而没有任何 1。尝试调试建议的解决方案。检查re
的输出是什么
在不放弃可扩展性的情况下,在 spark-sql(甚至 rdd)中应用分组层并不容易。所以@user3127157 的问题是表中有多少个“a”以上是关于Spark/Scala 1.6 如何使用 dataframe groupby agg 来实现以下逻辑?的主要内容,如果未能解决你的问题,请参考以下文章