Scala in practice: modifying the Spark source so a DataFrame can be written incrementally into a MySQL table by column

Posted by zfszhangyuan



In the previous post we used the interfaces documented on the official site to pull data from a MySQL table into Spark, and to save Spark results back into MySQL.

But you will find that when Spark writes its DataFrame to MySQL, no matter which save mode you choose:

jdbcDF.write.mode(SaveMode.Overwrite).jdbc(url,"zfs_test",prop)
jdbcDF.write.mode(SaveMode.Append).jdbc(url,"zbh_test",prop)
the result is always that the table gets rebuilt.

That means the table's existing data is gone, and if the table has other columns (say an auto-increment primary key id), there is nothing you can do about it.
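For example, suppose the target table already exists with an auto-increment primary key, something like the hypothetical table below (the table and column names are the ones used in the demo at the end, chosen purely for illustration):

import java.sql.DriverManager

// Hypothetical target table: an auto-increment id plus the columns we want to
// append into. The stock writer would drop and recreate it, losing the id column.
Class.forName("com.mysql.jdbc.Driver")
val conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/db_ldjs", "root", "123456")
conn.prepareStatement(
  """CREATE TABLE IF NOT EXISTS zfs_test1 (
    |  id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
    |  imei TEXT, region TEXT, city TEXT, company TEXT, name TEXT
    |)""".stripMargin).executeUpdate()
conn.close()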

The environment for this post is the same as in http://blog.csdn.net/zfszhangyuan/article/details/52593521

The Spark version is 1.5.2. This time we need to download the Spark source from the official site: http://www.apache.org/dist/spark/spark-1.5.2/

Choose spark-1.5.2.tgz and download it.

Add the downloaded source to the existing project.

Let's trace through the source to see why, no matter which mode we set, the result is always: drop the table, recreate it, then insert the data.

The root cause:

The mode is hard-coded; whatever you set earlier, Append or anything else, it ends up as Overwrite.

In addition, the statement Spark uses to insert data into MySQL is not ideal either:


It issues a plain INSERT INTO table VALUES (...), so the insert only succeeds if the target table's columns match the DataFrame's fields exactly in both name and order. When we want to insert DataFrame data into specified columns of a MySQL table, this approach simply cannot do it.

Now that the cause is clear, let's start optimizing the source.

The main changes: the insertStatement logic now lists the column names explicitly, the jdbc method gains a df: DataFrame parameter, and the default SaveMode becomes Append.
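For the five columns read in the demo at the end, the difference in the generated statement looks like this (a quick sketch; the table name zfs_test1 and the column list are just the ones used later):

val columns = Seq("imei", "region", "city", "company", "name")

// Stock Spark 1.5.2 writer: positional insert, columns implied by order.
val oldSql = s"INSERT INTO zfs_test1 VALUES (${columns.map(_ => "?").mkString(", ")})"

// Modified insertStatement: columns listed explicitly, so extra table columns
// (e.g. an auto-increment id) are simply left alone.
val newSql = s"INSERT INTO zfs_test1 (${columns.mkString(",")}) VALUES (${columns.map(_ => "?").mkString(", ")})"

println(oldSql) // INSERT INTO zfs_test1 VALUES (?, ?, ?, ?, ?)
println(newSql) // INSERT INTO zfs_test1 (imei,region,city,company,name) VALUES (?, ?, ?, ?, ?)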

To avoid touching the Spark source itself, we write our own JdbcUtils object (extending Logging) inside the project. The code is as follows:

package JDBC_MySql

import java.sql.{Connection, PreparedStatement}
import java.util.Properties

//import com.besttone.utils.{JDBCRDD, JdbcDialects}
import org.apache.spark.Logging
import org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SaveMode}

import scala.util.Try

/**
  * Util functions for JDBC tables.
  */
object JdbcUtils extends Logging {

  // Default save mode: Append, so an existing table (and its extra columns) is kept.
  val mode = SaveMode.Append

  def jdbc(url: String, df: DataFrame, table: String, connectionProperties: Properties): Unit = {
    val props = new Properties()
    props.putAll(connectionProperties)
    val conn = JdbcUtils.createConnection(url, props)

    try {
      var tableExists = JdbcUtils.tableExists(conn, table)

      if (mode == SaveMode.Ignore && tableExists) {
        return
      }

      if (mode == SaveMode.ErrorIfExists && tableExists) {
        sys.error(s"Table $table already exists.")
      }

      if (mode == SaveMode.Overwrite && tableExists) {
        JdbcUtils.dropTable(conn, table)
        tableExists = false
      }

      // Create the table if the table didn't exist.
      if (!tableExists) {
        val schema = JdbcUtils.schemaString(df, url)
        val sql = s"CREATE TABLE $table ($schema)"
        conn.prepareStatement(sql).executeUpdate()
      }
    } finally {
      conn.close()
    }

    JdbcUtils.saveTable(df, url, table, props)
  }

  /**
    * Establishes a JDBC connection.
    */
  def createConnection(url: String, connectionProperties: Properties): Connection = {
    JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)()
  }

  /**
    * Returns true if the table already exists in the JDBC database.
    */
  def tableExists(conn: Connection, table: String): Boolean = {
    // Somewhat hacky, but there isn't a good way to identify whether a table exists for all
    // SQL database systems, considering "table" could also include the database name.
    Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess
  }

  /**
    * Drops a table from the JDBC database.
    */
  def dropTable(conn: Connection, table: String): Unit = {
    conn.prepareStatement(s"DROP TABLE $table").executeUpdate()
  }

  /**
    * Returns a PreparedStatement that inserts a row into table via conn.
    * The column names from the DataFrame schema are listed explicitly, so the target
    * table only needs to contain these columns; any extra columns (e.g. an
    * auto-increment primary key) are left untouched.
    */
  def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
    val fields = rddSchema.fields
    val fieldsSql = new StringBuilder("(")
    var i = 0
    for (f <- fields) {
      fieldsSql.append(f.name)
      if (i == fields.length - 1) {
        fieldsSql.append(")")
      } else {
        fieldsSql.append(",")
      }
      i += 1
    }

    val sql = new StringBuilder(s"INSERT INTO $table ")
    sql.append(fieldsSql.toString())
    sql.append(" VALUES (")
    var fieldsLeft = rddSchema.fields.length
    while (fieldsLeft > 0) {
      sql.append("?")
      if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
      fieldsLeft = fieldsLeft - 1
    }
    //println(sql.toString())
    conn.prepareStatement(sql.toString())
  }

  /**
    * Saves a partition of a DataFrame to the JDBC database.  This is done in
    * a single database transaction in order to avoid repeatedly inserting
    * data as much as possible.
    *
    * It is still theoretically possible for rows in a DataFrame to be
    * inserted into the database more than once if a stage somehow fails after
    * the commit occurs but before the stage can return successfully.
    *
    * This is not a closure inside saveTable() because apparently cosmetic
    * implementation changes elsewhere might easily render such a closure
    * non-Serializable.  Instead, we explicitly close over all variables that
    * are used.
    */
  def savePartition(
                     getConnection: () => Connection,
                     table: String,
                     iterator: Iterator[Row],
                     rddSchema: StructType,
                     nullTypes: Array[Int]): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false
    try {
      conn.setAutoCommit(false) // Everything in the same db transaction.
      val stmt = insertStatement(conn, table, rddSchema)
      try {
        while (iterator.hasNext) {
          val row = iterator.next()
          val numFields = rddSchema.fields.length
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              rddSchema.fields(i).dataType match {
                case IntegerType => stmt.setInt(i + 1, row.getInt(i))
                case LongType => stmt.setLong(i + 1, row.getLong(i))
                case DoubleType => stmt.setDouble(i + 1, row.getDouble(i))
                case FloatType => stmt.setFloat(i + 1, row.getFloat(i))
                case ShortType => stmt.setInt(i + 1, row.getShort(i))
                case ByteType => stmt.setInt(i + 1, row.getByte(i))
                case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i))
                case StringType => stmt.setString(i + 1, row.getString(i))
                case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
                case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
                case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
                case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
                case _ => throw new IllegalArgumentException(
                  s"Can't translate non-null value for field $i")
              }
            }
            i = i + 1
          }
          stmt.executeUpdate()
        }
      } finally {
        stmt.close()
      }
      conn.commit()
      committed = true
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        conn.rollback()
        conn.close()
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
    Array[Byte]().iterator
  }

  /**
    * Compute the schema string for this RDD.
    */
  def schemaString(df: DataFrame, url: String): String = {
    val sb = new StringBuilder()
    val dialect = JdbcDialects.get(url)
    df.schema.fields foreach { field =>
      val name = field.name
      val typ: String =
        dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
          field.dataType match {
            case IntegerType => "INTEGER"
            case LongType => "BIGINT"
            case DoubleType => "DOUBLE PRECISION"
            case FloatType => "REAL"
            case ShortType => "INTEGER"
            case ByteType => "BYTE"
            case BooleanType => "BIT(1)"
            case StringType => "TEXT"
            case BinaryType => "BLOB"
            case TimestampType => "TIMESTAMP"
            case DateType => "DATE"
            case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
            case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
          })
      val nullable = if (field.nullable) "" else "NOT NULL"
      sb.append(s", $name $typ $nullable")
    }
    if (sb.length < 2) "" else sb.substring(2)
  }

  /**
    * Saves the RDD to the database in a single transaction.
    */
  def saveTable(
                 df: DataFrame,
                 url: String,
                 table: String,
                 properties: Properties = new Properties()) {
    val dialect = JdbcDialects.get(url)
    val nullTypes: Array[Int] = df.schema.fields.map { field =>
      dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
        field.dataType match {
          case IntegerType => java.sql.Types.INTEGER
          case LongType => java.sql.Types.BIGINT
          case DoubleType => java.sql.Types.DOUBLE
          case FloatType => java.sql.Types.REAL
          case ShortType => java.sql.Types.INTEGER
          case ByteType => java.sql.Types.INTEGER
          case BooleanType => java.sql.Types.BIT
          case StringType => java.sql.Types.CLOB
          case BinaryType => java.sql.Types.BLOB
          case TimestampType => java.sql.Types.TIMESTAMP
          case DateType => java.sql.Types.DATE
          case t: DecimalType => java.sql.Types.DECIMAL
          case _ => throw new IllegalArgumentException(
            s"Can't translate null value for field $field")
        })
    }

    val rddSchema = df.schema
    val driver: String = DriverRegistry.getDriverClassName(url)
    val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
    df.foreachPartition { iterator =>
      savePartition(getConnection, table, iterator, rddSchema, nullTypes)
    }
  }
}

Next, the classes this object depends on need to go into the same package:

JdbcDialects.scala:
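In Spark this logic lives in org.apache.spark.sql.jdbc.JdbcDialects; you can copy that whole file from the downloaded source into this package, or, since the JdbcUtils above only ever calls JdbcDialects.get(url) and dialect.getJDBCType(dataType), use a much smaller stand-in like the sketch below. The MySQL type overrides here are an assumption for this demo, not the full Spark 1.5.2 dialect.

package JDBC_MySql

import org.apache.spark.sql.types._

// The JDBC type to use for a column: the database column definition plus the
// java.sql.Types constant used for setNull().
case class JdbcType(databaseTypeDefinition: String, jdbcNullType: Int)

// Minimal dialect interface: only getJDBCType is needed by the JdbcUtils above.
abstract class JdbcDialect {
  def canHandle(url: String): Boolean
  def getJDBCType(dt: DataType): Option[JdbcType] = None
}

// Fallback when no dialect matches; JdbcUtils then uses its own default mappings.
case object NoopDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = true
}

// MySQL-specific overrides (assumed mappings for this demo).
case object MySQLDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mysql")
  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType  => Some(JdbcType("TEXT", java.sql.Types.LONGVARCHAR))
    case BooleanType => Some(JdbcType("BIT(1)", java.sql.Types.BIT))
    case _           => None
  }
}

object JdbcDialects {
  private var dialects = List[JdbcDialect](MySQLDialect)

  // Register an extra dialect; later registrations take precedence.
  def registerDialect(dialect: JdbcDialect): Unit = {
    dialects = dialect :: dialects.filterNot(_ == dialect)
  }

  // Return the first dialect that accepts the URL, otherwise the no-op fallback.
  def get(url: String): JdbcDialect = {
    dialects.find(_.canHandle(url)).getOrElse(NoopDialect)
  }
}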


JDBCRDD.scala:

package JDBC_MySql

/**
  * Created by zhoubh on 2016/7/22.
  */
import java.sql.{Connection, DriverManager}
import java.util.Properties

import org.apache.spark.Logging
import org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry


private object JDBCRDD extends Logging {
  def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
    () => {
      try {
        if (driver != null) DriverRegistry.register(driver)
      } catch {
        case e: ClassNotFoundException =>
          logWarning(s"Couldn't find class $driver", e)
      }
      DriverManager.getConnection(url, properties)
    }
  }
}
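getConnector returns a connection factory (a () => Connection) rather than a connection, so each partition can open its own connection on the executor. A quick local check, reusing the credentials from the demo below (run it from inside the JDBC_MySql package, since the object is package-private):

val props = new java.util.Properties()
props.setProperty("user", "root")
props.setProperty("password", "123456")

// Build the factory once, then call it to open (and later close) a connection.
val connect = JDBCRDD.getConnector("com.mysql.jdbc.Driver", "jdbc:mysql://localhost:3306/db_ldjs", props)
val conn = connect()
println(conn.getMetaData.getDatabaseProductName) // e.g. "MySQL"
conn.close()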
    
  

The test class with the main method, mysqlDB:

package JDBC_MySql

import java.util.Properties

import org.apache.spark.sql.SaveMode
import org.apache.spark.{SparkConf, SparkContext}

/**
  * Created by zhoubh on 2016/7/20.
  */
object mysqlDB {

  case class zbh_test(day_id: String, prvnce_id: String, pv_cnts: Int)

  def main(args: Array[String]) {

    val conf = new SparkConf().setAppName("mysql").setMaster("local[4]")
    val sc = new SparkContext(conf)
    //sc.addJar("D:\\workspace\\sparkApp\\lib\\mysql-connector-java-5.0.8-bin.jar")
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)

    // Define the MySQL source: read the selected columns of tb_user_imei as a DataFrame.
    val jdbcDF = sqlContext.read.format("jdbc").options(
      Map("url" -> "jdbc:mysql://localhost:3306/db_ldjs",
        "dbtable" -> "(select imei,region,city,company,name from tb_user_imei) as some_alias",
        "driver" -> "com.mysql.jdbc.Driver",
        "user" -> "root",
        //"partitionColumn" -> "day_id",
        "lowerBound" -> "0",
        "upperBound" -> "1000",
        //"numPartitions" -> "2",
        "fetchSize" -> "100",
        "password" -> "123456")).load()

    jdbcDF.collect().take(20).foreach(println)
    //jdbcDF.rdd.saveAsTextFile("C:/Users/zhoubh/Downloads/abi_sum")

    val url = "jdbc:mysql://localhost:3306/db_ldjs"
    val prop = new Properties()
    prop.setProperty("user", "root")
    prop.setProperty("password", "123456")

    // Built-in writer: still drops and recreates the table.
    //jdbcDF.write.mode(SaveMode.Overwrite).jdbc(url, "zfs_test", prop)
    jdbcDF.write.mode(SaveMode.Append).jdbc(url, "zbh_test", prop)

    // Our rewritten JdbcUtils: appends into the existing zfs_test1 table by column name.
    JdbcUtils.jdbc(url, jdbcDF, "zfs_test1", prop)

    //org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.saveTable(jdbcDF, url, "zbh_test", prop)

    // Then group to get the data set
    //    val abi_sum_area = abi_sum.groupBy("date_time", "area_name")

    // Count and sort by the count in descending order
    //    val sorted = abi_sum_area.count().orderBy("count")

    // Show the first 10 rows
    //    sorted.show(10)

    // Save to a file (this produces many partition files...)
    //    sorted.rdd.saveAsTextFile("C:/Users/zhoubh/Downloads/sparktest/flight_top")

    // Save to a MySQL table
    //    //sorted.write.jdbc(url, "table_name", prop)

  }
}

Run it in debug mode and you can see the result:
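To confirm the append without leaving Spark, one option (a small sketch, added at the end of main above and reusing its sqlContext, url and prop) is to read the target table back:

// Re-read the table we just wrote and inspect it.
val written = sqlContext.read.jdbc(url, "zfs_test1", prop)
println(written.count()) // the row count should grow by jdbcDF.count() on each run
written.show(5)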




For reference, example spark-submit commands for running jobs on the production machines:

 spark-submit --class com.besttone.UserOnlineAnalysis --master yarn-client --executor-memory 2g --num-executors 3 file:///home/hadoop/test/sparkApp.jar test/apponoff.bz2 test/out22
 spark-submit --class com.besttone.utils.TestMysql --master yarn-client --executor-memory 2g --num-executors 3 file:///home/hadoop/file/sparkApp.jar 
 spark-submit --class com.besttone.app.Appo2oProcess   --master yarn-client --executor-memory 2g --num-executors 3 file:///home/hadoop/file/sparkApp.jar /user/hadoop/20160804/appo2olog /user/hive/warehouse/tmp_appo2olog1

