FlatMap from a column value to multiple rows missing schema
Posted: 2021-02-25 16:48:23
I'm trying to transform the dataframe below into multiple rows based on a column value. I believe the Row is missing its schema (line 128 of my code, in the second flatMap) and an exception is thrown.
Original dataframe:
+---+------------------------+
|Id |Set |
+---+------------------------+
|1 |AA001-AA003, BB002-BB003|
|2 |AA045-AA046, CC099-CC100|
+---+------------------------+
Intermediate dataframe step, shown for clarification:
+---+-----------+
| Id| Set|
+---+-----------+
| 1|AA001-AA003|
| 1|BB002-BB003|
| 2|AA045-AA046|
| 2|CC099-CC100|
+---+-----------+
Final dataframe:
+---+-------+------+------+
| Id|Combine|Letter|Number|
+---+-------+------+------+
| 1| AA001| AA| 1|
| 1| AA002| AA| 2|
| 1| AA003| AA| 3|
| 1| BB002| BB| 2|
| 1| BB003| BB| 3|
| 2| AA045| AA| 45|
| 2| AA046| AA| 46|
| 2| CC099| CC| 99|
| 2| CC100| CC| 100|
+---+-------+------+------+
Here is the full sample application; this is where I get the exception:
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
public class SampleApp implements Serializable {

    private static final long serialVersionUID = -1L;

    private static String ID = "Id";
    private static String SET = "Set";
    private static String COMBINE = "Combine";
    private static String LETTER = "Letter";
    private static String NUMBER = "Number";

    public static void main(String[] args) {
        SampleApp app = new SampleApp();
        app.start();
    }

    private void start() {
        Logger.getLogger("org.apache").setLevel(Level.WARN);

        SparkSession spark = SparkSession
                .builder()
                .appName("Spark App")
                .master("local[*]")
                .getOrCreate();

        StructType commaStructType = new StructType();
        commaStructType = commaStructType.add(ID, DataTypes.IntegerType, false);
        commaStructType = commaStructType.add(SET, DataTypes.StringType, true);

        StructType resultStructType = new StructType();
        resultStructType = resultStructType.add(ID, DataTypes.IntegerType, false);
        resultStructType = resultStructType.add(COMBINE, DataTypes.StringType, false);
        resultStructType = resultStructType.add(LETTER, DataTypes.StringType, false);
        resultStructType = resultStructType.add(NUMBER, DataTypes.IntegerType, false);

        List<Row> list = new ArrayList<Row>();
        list.add(RowFactory.create(1, "AA001-AA003, BB002-BB003"));
        list.add(RowFactory.create(2, "AA045-AA046, CC099-CC100"));

        Dataset<Row> df = spark.createDataFrame(list, commaStructType);
        df.show(10, false);
        df.printSchema();

        Dataset<Row> commaSeparatedDf = df.flatMap(new separateByCommaFlatMap(), RowEncoder.apply(commaStructType));
        commaSeparatedDf.show(10, true);
        commaSeparatedDf.printSchema();

        Dataset<Row> resultDf = commaSeparatedDf.flatMap(new separateByDashFlatMap(), RowEncoder.apply(resultStructType));
        resultDf.show(10, true);
        resultDf.printSchema();

        /* This manually created DataFrame for the final step works */
        /*
        List<Row> list2 = new ArrayList<Row>();
        list2.add(RowFactory.create(1, "AA001-AA003"));
        list2.add(RowFactory.create(1, "BB002-BB003"));
        list2.add(RowFactory.create(2, "AA045-AA046"));
        list2.add(RowFactory.create(2, "CC099-CC100"));

        Dataset<Row> df2 = spark.createDataFrame(list2, commaStructType);
        df2.show(10, true);
        df2.printSchema();

        Dataset<Row> resultDf2 = df2.flatMap(new separateByDashFlatMap(), RowEncoder.apply(resultStructType));
        resultDf2.show(10, true);
        resultDf2.printSchema();
        */
    }

    /*
     * Split "AA001-AA003, BB002-BB003" into individual rows:
     * AA001-AA003
     * BB002-BB003
     */
    private final class separateByCommaFlatMap implements FlatMapFunction<Row, Row> {
        private static final long serialVersionUID = 63784L;

        @Override
        public Iterator<Row> call(Row r) throws Exception {
            int id = Integer.parseInt(r.getAs(ID).toString());
            String[] s = r.getAs(SET).toString().split(", ");
            List<Row> list = new ArrayList<Row>();
            for (int i = 0; i < s.length; i++) {
                List<Object> data = new ArrayList<>();
                data.add(id);
                data.add(s[i]);
                list.add(RowFactory.create(data.toArray()));
            }
            return list.iterator();
        }
    }

    /*
     * Split "AA001-AA003" into individual rows:
     * AA001 | AA | 1
     * AA002 | AA | 2
     * AA003 | AA | 3
     */
    private final class separateByDashFlatMap implements FlatMapFunction<Row, Row> {
        private static final long serialVersionUID = 63784L;

        @Override
        public Iterator<Row> call(Row r) throws Exception {
            int id = r.getAs(ID);
            String[] s = r.getAs(SET).toString().split("-");
            String letter = s[0].substring(0, 2);
            int start = Integer.parseInt(s[0].substring(2, s[0].length()));
            int end = Integer.parseInt(s[1].substring(2, s[1].length()));

            List<Row> list = new ArrayList<Row>();
            for (int i = start; i <= end; i++) {
                List<Object> data = new ArrayList<>();
                data.add(id);
                data.add(String.format("%s%03d", letter, i));
                data.add(letter);
                data.add(i);
                list.add(RowFactory.create(data.toArray()));
            }
            return list.iterator();
        }
    }
}
Comments:
Do you need to use flatMap? Would you consider using the dataframe API instead?
@mck I don't need flatMap. I'm open to any suggestions.
OK, what is your Spark version?
Spark version 3.0.1.
If you're getting an exception, please include its stack trace in your question.
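For what it's worth, the exception in this pattern is usually Spark's "java.lang.UnsupportedOperationException: fieldIndex on a Row without schema is undefined": Rows built with RowFactory.create() carry no schema, and when two flatMap calls are chained, Spark's optimizer can drop the serialization step between them and hand the schema-less rows from the first function straight to the second, where the by-name lookup r.getAs(ID) has no field names to resolve. Assuming that is the error here, a minimal workaround is to read the fields positionally in separateByDashFlatMap:

// Sketch of a workaround, assuming the schema-less Row lookup described
// above: access the fields by position instead of by name.
int id = r.getInt(0);                // was: r.getAs(ID)
String set = r.getString(1);         // was: r.getAs(SET).toString()
String[] s = set.split("-");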
Answer 1:
Here is a solution based on the dataframe API:
import org.apache.spark.sql.functions._
Dataset<Row> result = df.withColumn(
"Set",
explode(split(col("Set"), ", ")) // split by comma and explode into rows
).withColumn(
"Letter",
substring(col("Set"), 1, 2) // get letter from first two chars
).withColumn(
"Number", // get and explode a list of numbers using Spark SQL sequence function
expr("""
explode(sequence(
int(substring(split(Set, '-')[0], 3)),
int(substring(split(Set, '-')[1], 3))
))
""")
).withColumn(
"Combine", // get formatted string for combine column
format_string("%s%03d", col("Letter"), col("Number"))
).select(
"ID", "Combine", "Letter", "Number"
)
result.show()
+---+-------+------+------+
| ID|Combine|Letter|Number|
+---+-------+------+------+
| 1| AA001| AA| 1|
| 1| AA002| AA| 2|
| 1| AA003| AA| 3|
| 1| BB002| BB| 2|
| 1| BB003| BB| 3|
| 2| AA045| AA| 45|
| 2| AA046| AA| 46|
| 2| CC099| CC| 99|
| 2| CC100| CC| 100|
+---+-------+------+------+
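Note that the expr("""...""") call above relies on Scala's triple-quoted strings; in Java, the same Spark SQL expression can be passed as one ordinary concatenated string. A possible (untested) Java adaptation of that step, where "exploded" stands for the dataframe after the first explode:

import static org.apache.spark.sql.functions.expr;

// Same sequence/explode expression, written as a plain Java string.
Dataset<Row> withNumber = exploded.withColumn(
    "Number",
    expr("explode(sequence("
        + "int(substring(split(Set, '-')[0], 3)), "
        + "int(substring(split(Set, '-')[1], 3))))"));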
Comments:
I tested this in Scala, but the API is about the same, so hopefully it works in Java too.
Answer 2:
I'm adding my own answer for anyone who wants the Java code. Special thanks to mck!
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.functions.explode;
import static org.apache.spark.sql.functions.split;
import static org.apache.spark.sql.functions.substring;
import static org.apache.spark.sql.functions.sequence;
import static org.apache.spark.sql.functions.format_string;
import static org.apache.spark.sql.functions.col;
public class ExplodeApp implements Serializable {

    private static final long serialVersionUID = -1L;

    private static String ID = "Id";
    private static String SET = "Set";
    private static String COMBINE = "Combine";
    private static String LETTER = "Letter";
    private static String NUMBER = "Number";
    private static String RANGE = "Range";

    public static void main(String[] args) {
        ExplodeApp app = new ExplodeApp();
        app.start();
    }

    private void start() {
        Logger.getLogger("org.apache").setLevel(Level.WARN);

        SparkSession spark = SparkSession
                .builder()
                .appName("Spark App")
                .master("local[*]")
                .getOrCreate();

        StructType commaStructType = new StructType();
        commaStructType = commaStructType.add(ID, DataTypes.IntegerType, false);
        commaStructType = commaStructType.add(SET, DataTypes.StringType, true);

        List<Row> list = new ArrayList<Row>();
        list.add(RowFactory.create(1, "AA001-AA003, BB002-BB003"));
        list.add(RowFactory.create(2, "AA045-AA046, CC099-CC100"));

        Dataset<Row> df = spark.createDataFrame(list, commaStructType);
        df.show(10, false);

        Column[] columnNames = new Column[] { col(ID), col(COMBINE), col(LETTER), col(NUMBER) };

        Dataset<Row> resultDf = df
                // split "AA001-AA003, BB002-BB003" on ", " and explode into one range per row
                .withColumn(RANGE, explode(split(df.col(SET), ", ")))
                // the first two characters of the range are the letter prefix
                .withColumn(LETTER, substring(col(RANGE), 1, 2))
                // split the range on "-", parse the numeric bounds, and explode the sequence
                .withColumn(NUMBER, explode(
                        sequence(substring(split(col(RANGE), "-").getItem(0), 3, 3).cast(DataTypes.IntegerType),
                                 substring(split(col(RANGE), "-").getItem(1), 3, 3).cast(DataTypes.IntegerType))))
                // rebuild the zero-padded combined code, e.g. "AA001"
                .withColumn(COMBINE, format_string("%s%03d", col(LETTER), col(NUMBER)))
                .select(columnNames);

        resultDf.show(10, false);
    }
}
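One caveat applies to both answers: they assume every code is exactly two letters followed by three digits, which is what substring(col(RANGE), 1, 2), the three-character substrings, and the "%s%03d" format encode. If the widths could vary, regular expressions are a safer way to take the range apart. A hypothetical variant of those lines (not part of the original answer; RANGE is the same "Range" column as above):

import static org.apache.spark.sql.functions.regexp_extract;

// Extract an alphabetic prefix and numeric bounds of any width;
// regexp_extract returns "" when the pattern does not match.
Column first = split(col(RANGE), "-").getItem(0);  // e.g. "AA001"
Column last = split(col(RANGE), "-").getItem(1);   // e.g. "AA003"
Column letter = regexp_extract(first, "^([A-Za-z]+)", 1);
Column start = regexp_extract(first, "([0-9]+)$", 1).cast(DataTypes.IntegerType);
Column end = regexp_extract(last, "([0-9]+)$", 1).cast(DataTypes.IntegerType);

The zero padding in format_string("%s%03d", ...) would likewise need adjusting if the digit width varies.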