Spark 数据集案例类编码器的小数精度

Posted

技术标签:

【中文标题】Spark 数据集案例类编码器的小数精度【英文标题】:Decimal precision for Spark Dataset case class Encoder 【发布时间】:2018-12-07 10:38:59 【问题描述】:

我有 CSV 数据:

"id","price"
"1","79.07"
"2","91.27"
"3","85.6"

使用SparkSession阅读它:

def readToDs(resource: String, schema: StructType): Dataset = 
    sparkSession.read
      .option("header", "true")
      .schema(schema)
      .csv(resource)
      .as[ItemPrice]

案例分类:

case class ItemPrice(id: Long, price: BigDecimal)

打印数据集:

def main(args: Array[String]): Unit = 
    val prices: Dataset = 
        readToDs("src/main/resources/app/data.csv", Encoders.product[ItemPrice].schema);
    prices.show();

输出:

+----------+--------------------+
|        id|               price|
+----------+--------------------+
|         1|79.07000000000000...|
|         2|91.27000000000000...|
|         3|85.60000000000000...|
+----------+--------------------+

期望的输出:

+----------+--------+
|        id|   price|
+----------+--------+
|         1|   79.07|
|         2|   91.27|
|         3|   85.6 |
+----------+--------+

我已经知道的选项:

使用硬编码的列顺序和数据类型手动定义架构,例如:

def defineSchema(): StructType =
    StructType(
      Seq(StructField("id", LongType, nullable = false)) :+
        StructField("price", DecimalType(3, 2), nullable = false)
    )

并像这样使用它:

val prices: Dataset = readToDs("src/main/resources/app/data.csv", defineSchema);

如何在不手动定义所有结构的情况下设置精度(3,2)

【问题讨论】:

您能否检查一下答案中的 UPDATE1 【参考方案1】:

假设您将 csv 设置为

scala> val df = Seq(("1","79.07","89.04"),("2","91.27","1.02"),("3","85.6","10.01")).toDF("item","price1","price2")
df: org.apache.spark.sql.DataFrame = [item: string, price1: string ... 1 more field]

scala> df.printSchema
root
 |-- item: string (nullable = true)
 |-- price1: string (nullable = true)
 |-- price2: string (nullable = true)

你可以像下面这样投射它

scala> val df2 = df.withColumn("price1",'price1.cast(DecimalType(4,2)))
df2: org.apache.spark.sql.DataFrame = [item: string, price1: decimal(4,2) ... 1 more field]

scala> df2.printSchema
root
 |-- item: string (nullable = true)
 |-- price1: decimal(4,2) (nullable = true)
 |-- price2: string (nullable = true)


scala>

现在,如果您知道 csv.. 中的十进制列列表,您可以像下面这样动态地进行操作

scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._

scala> val decimal_cols = Array("price1","price2")
decimal_cols: Array[String] = Array(price1, price2)

scala> val df3 = decimal_cols.foldLeft(df) (acc,r) => acc.withColumn(r,col(r).cast(DecimalType(4,2))) 
df3: org.apache.spark.sql.DataFrame = [item: string, price1: decimal(4,2) ... 1 more field]

scala> df3.show
+----+------+------+
|item|price1|price2|
+----+------+------+
|   1| 79.07| 89.04|
|   2| 91.27|  1.02|
|   3| 85.60| 10.01|
+----+------+------+


scala> df3.printSchema
root
 |-- item: string (nullable = true)
 |-- price1: decimal(4,2) (nullable = true)
 |-- price2: decimal(4,2) (nullable = true)


scala>

这有帮助吗?

更新1:

使用 inferSchema 读取 csv 文件,然后将所有双字段动态转换为 DecimalType(4,2)。

val df = spark.read.format("csv").option("header","true").option("inferSchema","true").load("in/items.csv")
df.show
df.printSchema()
val decimal_cols = df.schema.filter( x=> x.dataType.toString == "DoubleType" ).map(x=>x.name)
// or df.schema.filter( x=> x.dataType==DoubleType )
val df3 = decimal_cols.foldLeft(df) (acc,r) => acc.withColumn(r,col(r).cast(DecimalType(4,2))) 
df3.printSchema()
df3.show()

结果:

+-----+------+------+
|items|price1|price2|
+-----+------+------+
|    1| 79.07| 89.04|
|    2| 91.27|  1.02|
|    3|  85.6| 10.01|
+-----+------+------+

root
 |-- items: integer (nullable = true)
 |-- price1: double (nullable = true)
 |-- price2: double (nullable = true)

root
 |-- items: integer (nullable = true)
 |-- price1: decimal(4,2) (nullable = true)
 |-- price2: decimal(4,2) (nullable = true)

+-----+------+------+
|items|price1|price2|
+-----+------+------+
|    1| 79.07| 89.04|
|    2| 91.27|  1.02|
|    3| 85.60| 10.01|
+-----+------+------+

【讨论】:

【参考方案2】:

一个选项是为输入模式定义一个转换器:

def defineDecimalType(schema: StructType): StructType = 
    new StructType(
      schema.map 
        case StructField(name, dataType, nullable, metadata) =>
          if (dataType.isInstanceOf[DecimalType])
            // Pay attention to max precision in the source data
            StructField(name, new DecimalType(20, 2), nullable, metadata)
          else 
            StructField(name, dataType, nullable, metadata)
      .toArray
    )
 

def main(args: Array[String]): Unit = 
    val prices: Dataset = 
        readToDs("src/main/resources/app/data.csv", defineDecimalType(Encoders.product[ItemPrice].schema));
    prices.show();

这种方法的缺点是这种映射适用于每一列,如果你有一个不适合精确精度的ID(比如说ID = 10000DecimalType(3, 2))你会得到一个例外:

原因:java.lang.IllegalArgumentException:要求失败: 小数精度 4 超过最大精度 3 scala.Predef$.require(Predef.scala:224) 在 org.apache.spark.sql.types.Decimal.set(Decimal.scala:113) 在 org.apache.spark.sql.types.Decimal$.apply(Decimal.scala:426) 在 org.apache.spark.sql.execution.datasources.csv.CSVTypeCast$.castTo(CSVInferSchema.scala:27​​3) 在 org.apache.spark.sql.execution.datasources.csv.CSVRelation$$anonfun$csvParser$3.apply(CSVRelation.scala:125) 在 org.apache.spark.sql.execution.datasources.csv.CSVRelation$$anonfun$csvParser$3.apply(CSVRelation.scala:94) 在 org.apache.spark.sql.execution.datasources.csv.CSVFileFormat$$anonfun$buildReader$1$$anonfun$apply$2.apply(CSVFileFormat.scala:167) 在 org.apache.spark.sql.execution.datasources.csv.CSVFileFormat$$anonfun$buildReader$1$$anonfun$apply$2.apply(CSVFileFormat.scala:166)

这就是为什么在源数据中保持精度高于最大小数点很重要的原因:

if (dataType.isInstanceOf[DecimalType])
    StructField(name, new DecimalType(20, 2), nullable, metadata)

【讨论】:

【参考方案3】:

我尝试使用 2 个不同的 CSV 文件加载示例数据,它工作正常,结果与以下代码的预期一致。 我在 Windows 上使用 Spark 2.3.1。

//read with double quotes
val df1 = spark.read
.format("csv")
.option("header","true")
.option("inferSchema","true")
.option("nullValue","")
.option("mode","failfast")
.option("path","D:/bitbuket/spark-examples/53667822/string.csv")
.load()

df1.show
/*
scala> df1.show
+---+-----+
| id|price|
+---+-----+
|  1|79.07|
|  2|91.27|
|  3| 85.6|
+---+-----+
*/

//read with without quotes
val df2 = spark.read
.format("csv")
.option("header","true")
.option("inferSchema","true")
.option("nullValue","")
.option("mode","failfast")
.option("path","D:/bitbuket/spark-examples/53667822/int-double.csv")
.load()

df2.show

/*
scala> df2.show
+---+-----+
| id|price|
+---+-----+
|  1|79.07|
|  2|91.27|
|  3| 85.6|
+---+-----+
*/

【讨论】:

这肯定会使用DataFrame 提供原样输出,因为它在没有显式转换的情况下以String 对数据进行操作。但这是 Dataset 的另一个故事。

以上是关于Spark 数据集案例类编码器的小数精度的主要内容,如果未能解决你的问题,请参考以下文章

Spark 案例类 - 十进制类型编码器错误“无法从十进制向上转换”

具有特征的 Spark 2.0 数据集编码器

在 Apache Spark Dataset<Row> 上应用 flatMap 操作时出现意外的编码器行为

Scala 集合的数据集编码器

Spark 数据集:示例:无法生成编码器问题

Spark 中的数组数据集 (1.6.1)