Training a LightGBM Model with Spark-Scala

Posted by 算法美食屋




Spark-Scala can use LightGBM models for both distributed training and distributed prediction, with support for the full range of parameter settings.
Model saving is supported as well, and a saved native model can be loaded from Python and other language bindings.
Note that when training a LightGBM model in Spark-Scala, the training dataset must be prepared as a DataFrame: use spark.ml.feature.VectorAssembler to combine the individual feature columns into a single features vector column, with the label kept as a separate column.
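As a minimal sketch of that preprocessing step (the column names x1, x2, x3 and the file path are placeholders, not the breast-cancer dataset used in the full example below):

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()

// Hypothetical raw data with numeric feature columns x1, x2, x3 and a label column.
val dfRaw = spark.read.option("header", "true").option("inferSchema", "true").csv("data/example.csv")

// Combine the raw feature columns into a single vector column named "features".
val assembler = new VectorAssembler()
  .setInputCols(Array("x1", "x2", "x3"))
  .setOutputCol("features")

// The result has exactly the two columns LightGBM expects: features and label.
val dfReady = assembler.transform(dfRaw).select("features", "label")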

1. Environment Configuration

To use LightGBM in Spark-Scala, configure the following dependencies in your pom file.

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_${scala.version}</artifactId>
    <version>${spark.version}</version>
    <!-- spark-mllib needs the pmml-model dependency excluded -->
    <exclusions>
        <exclusion>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-model</artifactId>
        </exclusion>
    </exclusions>
</dependency>

<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>jpmml-sparkml</artifactId>
    <version>1.3.4</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>jpmml-lightgbm</artifactId>
    <version>1.3.4</version>
</dependency>
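The LightGBMClassifier used in the example code lives in the com.microsoft.ml.spark namespace, which is provided by Microsoft's mmlspark package (the project has since been renamed SynapseML), so that package must also be on the classpath. A sketch of the extra dependency follows; the Scala-version suffix and version number here are assumptions to be matched to your environment, and note that mmlspark artifacts are published in Microsoft's own Maven repository rather than Maven Central.

<!-- Assumed coordinates for the mmlspark package that provides LightGBMClassifier; adjust to your Scala/Spark versions. -->
<dependency>
    <groupId>com.microsoft.ml.spark</groupId>
    <artifactId>mmlspark_2.11</artifactId>
    <version>0.18.1</version>
</dependency>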

2. Example Code

Below we take a binary classification problem as an example and walk through the demo code in the following familiar steps:
  • 1. Prepare the data
  • 2. Define the model
  • 3. Train the model
  • 4. Evaluate the model
  • 5. Use the model
  • 6. Save the model

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType, IntegerType}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.attribute.Attribute
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import com.microsoft.ml.spark.{lightgbm => lgb}
import com.google.gson.{JsonObject, JsonParser}
import scala.collection.JavaConverters._

object LgbDemo extends Serializable {

     def printlog(info: String): Unit = {
         val dt = new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new java.util.Date)
         println("==========" * 8 + dt)
         println(info + "\n")
     }

     def main(args: Array[String]): Unit = {

     /*================================================================================*/
     //  1. Load the data
     /*================================================================================*/
     printlog("step1: preparing data ...")

     // Load the data
     val spark = SparkSession.builder().getOrCreate()
     val dfdata_raw = spark.read.option("header", "true")
         .option("delimiter", "\t")
         .option("inferSchema", "true")
         .option("nullValue", "")
         .csv("data/breast_cancer.csv")

     dfdata_raw.sample(false, 0.1, 1).printSchema

     // Assemble the feature columns into a single features vector
     val feature_cols = dfdata_raw.columns.filter(!Array("label").contains(_))
     val cate_cols = Array("mean_radius", "mean_texture")

     val vectorAssembler = new VectorAssembler()
       .setInputCols(feature_cols)
       .setOutputCol("features")

     val dfdata = vectorAssembler.transform(dfdata_raw).select("features", "label")
     val Array(dftrain, dfval) = dfdata.randomSplit(Array(0.7, 0.3), 666)

     // The feature names are stored in the metadata of the features column's schema,
     // so categorical features can be specified by feature name.
     println(dfdata.schema("features").metadata)
     dfdata.show(10)

     /*================================================================================*/
     //  2. Define the model
     /*================================================================================*/
     printlog("step2: defining model ...")

     val lgbclassifier = new lgb.LightGBMClassifier()
       .setNumIterations(100)
       .setLearningRate(0.1)
       .setNumLeaves(31)
       .setMinSumHessianInLeaf(0.001)
       .setMaxDepth(-1)
       .setBoostFromAverage(false)
       .setFeatureFraction(1.0)
       .setMaxBin(255)
       .setLambdaL1(0.0)
       .setLambdaL2(0.0)
       .setBaggingFraction(1.0)
       .setBaggingFreq(0)
       .setBaggingSeed(1)
       .setBoostingType("gbdt")    // gbdt, rf, dart, goss
       .setCategoricalSlotNames(cate_cols)
       .setObjective("binary")     // binary, multiclass
       .setFeaturesCol("features")
       .setLabelCol("label")

     println(lgbclassifier.explainParams)


     /*================================================================================*/
     //  3. Train the model
     /*================================================================================*/
     printlog("step3: training model ...")

     val lgbmodel = lgbclassifier.fit(dftrain)

     val feature_importances = lgbmodel.getFeatureImportances("gain")
     val arr = feature_cols.zip(feature_importances).sortBy[Double](t => -t._2)
     val dfimportance = spark.createDataFrame(arr).toDF("feature_name", "feature_importance(gain)")

     dfimportance.show(100)


     /*================================================================================*/
     //  4. Evaluate the model
     /*================================================================================*/
     printlog("step4: evaluating model ...")

     val evaluator = new BinaryClassificationEvaluator()
       .setLabelCol("label")
       .setRawPredictionCol("rawPrediction")
       .setMetricName("areaUnderROC")

     val dftrain_result = lgbmodel.transform(dftrain)
     val dfval_result = lgbmodel.transform(dfval)

     val train_auc = evaluator.evaluate(dftrain_result)
     val val_auc = evaluator.evaluate(dfval_result)
     println(s"train_auc = ${train_auc}")
     println(s"val_auc = ${val_auc}")


     /*================================================================================*/
     //  5. Use the model
     /*================================================================================*/
     printlog("step5: using model ...")

     // batch prediction
     val dfpredict = lgbmodel.transform(dfval)
     dfpredict.sample(false, 0.1, 1).show(20)

     // predict a single sample
     val features = dfval.head().getAs[Vector]("features")
     val single_result = lgbmodel.predict(features)

     println(single_result)


     /*================================================================================*/
     //  6. Save the model
     /*================================================================================*/
     printlog("step6: saving model ...")

     // save to the cluster (multiple files)
     lgbmodel.write.overwrite().save("lgbmodel.model")
     // load the model saved on the cluster
     println("load model ...")
     val lgbmodel_loaded = lgb.LightGBMClassificationModel.load("lgbmodel.model")
     val dfresult = lgbmodel_loaded.transform(dfval)
     dfresult.show()

     // save locally as a single native model file, compatible with the Python API
     // lgbmodel.saveNativeModel("lgb_model", true)
     // load the local native model
     // val lgbmodel_loaded = lgb.LightGBMClassificationModel.loadNativeModelFromFile("lgb_model")
    
    }
    
}

3. Reference Output

杩愯濡備笂浠g爜涔嬪悗锛屽彲浠ュ緱鍒板涓嬭緭鍑恒€?
娉ㄦ剰 println(lgbclassifier.explainParams)鍙互鑾峰彇LightGBM妯″瀷鍚勪釜鍙傛暟鐨勫惈涔変互鍙婇粯璁ゅ€笺€?
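If only one or two parameters are of interest, the generic Spark ML Params API that LightGBMClassifier inherits can be queried instead of dumping everything; a minimal sketch, reusing the lgbclassifier defined above:

// List the names of all parameters defined on the estimator.
lgbclassifier.params.foreach(p => println(p.name))

// Explain a single parameter, looked up by name (here: learningRate).
lgbclassifier.params.find(_.name == "learningRate")
  .foreach(p => println(lgbclassifier.explainParam(p)))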

================================================================================2021-07-17 22:16:29
step1: preparing data ...

root
|-- mean_radius: integer (nullable = true)
|-- mean_texture: integer (nullable = true)
|-- mean_perimeter: double (nullable = true)
|-- mean_area: double (nullable = true)
|-- mean_smoothness: double (nullable = true)
|-- mean_compactness: double (nullable = true)
|-- mean_concavity: double (nullable = true)
|-- mean_concave_points: double (nullable = true)
|-- mean_symmetry: double (nullable = true)
|-- mean_fractal_dimension: double (nullable = true)
|-- radius_error: double (nullable = true)
|-- texture_error: double (nullable = true)
|-- perimeter_error: double (nullable = true)
|-- area_error: double (nullable = true)
|-- smoothness_error: double (nullable = true)
|-- compactness_error: double (nullable = true)
|-- concavity_error: double (nullable = true)
|-- concave_points_error: double (nullable = true)
|-- symmetry_error: double (nullable = true)
|-- fractal_dimension_error: double (nullable = true)
|-- worst_radius: double (nullable = true)
|-- worst_texture: double (nullable = true)
|-- worst_perimeter: double (nullable = true)
|-- worst_area: double (nullable = true)
|-- worst_smoothness: double (nullable = true)
|-- worst_compactness: double (nullable = true)
|-- worst_concavity: double (nullable = true)
|-- worst_concave_points: double (nullable = true)
|-- worst_symmetry: double (nullable = true)
|-- worst_fractal_dimension: double (nullable = true)
|-- label: integer (nullable = true)

{"ml_attr":{"attrs":{"numeric":[{"idx":0,"name":"mean_radius"},{"idx":1,"name":"mean_texture"},{"idx":2,"name":"mean_perimeter"},{"idx":3,"name":"mean_area"},{"idx":4,"name":"mean_smoothness"},{"idx":5,"name":"mean_compactness"},{"idx":6,"name":"mean_concavity"},{"idx":7,"name":"mean_concave_points"},{"idx":8,"name":"mean_symmetry"},{"idx":9,"name":"mean_fractal_dimension"},{"idx":10,"name":"radius_error"},{"idx":11,"name":"texture_error"},{"idx":12,"name":"perimeter_error"},{"idx":13,"name":"area_error"},{"idx":14,"name":"smoothness_error"},{"idx":15,"name":"compactness_error"},{"idx":16,"name":"concavity_error"},{"idx":17,"name":"concave_points_error"},{"idx":18,"name":"symmetry_error"},{"idx":19,"name":"fractal_dimension_error"},{"idx":20,"name":"worst_radius"},{"idx":21,"name":"worst_texture"},{"idx":22,"name":"worst_perimeter"},{"idx":23,"name":"worst_area"},{"idx":24,"name":"worst_smoothness"},{"idx":25,"name":"worst_compactness"},{"idx":26,"name":"worst_concavity"},{"idx":27,"name":"worst_concave_points"},{"idx":28,"name":"worst_symmetry"},{"idx":29,"name":"worst_fractal_dimension"}]},"num_attrs":30}}
+--------------------+-----+
| features|label|
+--------------------+-----+
|[17.0,10.0,122.8,...| 0|
|[20.0,17.0,132.9,...| 0|
|[19.0,21.0,130.0,...| 0|
|[11.0,20.0,77.58,...| 0|
|[20.0,14.0,135.1,...| 0|
|[12.0,15.0,82.57,...| 0|
|[18.0,19.0,119.6,...| 0|
|[13.0,20.0,90.2,5...| 0|
|[13.0,21.0,87.5,5...| 0|
|[12.0,24.0,83.97,...| 0|
+--------------------+-----+
only showing top 10 rows

================================================================================2021-07-17 22:16:29
step2: defining model ...

baggingFraction: Bagging fraction (default: 1.0, current: 1.0)
baggingFreq: Bagging frequency (default: 0, current: 0)
baggingSeed: Bagging seed (default: 3, current: 1)
boostFromAverage: Adjusts initial score to the mean of labels for faster convergence (default: true, current: false)
boostingType: Default gbdt = traditional Gradient Boosting Decision Tree. Options are: gbdt, gbrt, rf (Random Forest), random_forest, dart (Dropouts meet Multiple Additive Regression Trees), goss (Gradient-based One-Side Sampling). (default: gbdt, current: gbdt)
categoricalSlotIndexes: List of categorical column indexes, the slot index in the features column (undefined)
categoricalSlotNames: List of categorical column slot names, the slot name in the features column (current: [Ljava.lang.String;@351fb3fc)
defaultListenPort: The default listen port on executors, used for testing (default: 12400)
earlyStoppingRound: Early stopping round (default: 0)
featureFraction: Feature fraction (default: 1.0, current: 1.0)
featuresCol: features column name (default: features, current: features)
initScoreCol: The name of the initial score column, used for continued training (undefined)
isProvideTrainingMetric: Whether output metric result over training dataset. (default: false)
isUnbalance: Set to true if training data is unbalanced in binary classification scenario (default: false)
labelCol: label column name (default: label, current: label)
lambdaL1: L1 regularization (default: 0.0, current: 0.0)
lambdaL2: L2 regularization (default: 0.0, current: 0.0)
learningRate: Learning rate or shrinkage rate (default: 0.1, current: 0.1)
maxBin: Max bin (default: 255, current: 255)
maxDepth: Max depth (default: -1, current: -1)
minSumHessianInLeaf: Minimal sum hessian in one leaf (default: 0.001, current: 0.001)
modelString: LightGBM model to retrain (default: )
numBatches: If greater than 0, splits data into separate batches during training (default: 0)
numIterations: Number of iterations, LightGBM constructs num_class * num_iterations trees (default: 100, current: 100)
numLeaves: Number of leaves (default: 31, current: 31)
objective: The Objective. For regression applications, this can be: regression_l2, regression_l1, huber, fair, poisson, quantile, mape, gamma or tweedie. For classification applications, this can be: binary, multiclass, or multiclassova. (default: binary, current: binary)
parallelism: Tree learner parallelism, can be set to data_parallel or voting_parallel (default: data_parallel)
predictionCol: prediction column name (default: prediction)
probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities (default: probability)
rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)
thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold (undefined)
timeout: Timeout in seconds (default: 1200.0)
useBarrierExecutionMode: Use new barrier execution mode in Beta testing, off by default. (default: false)
validationIndicatorCol: Indicates whether the row is for training or validation (undefined)
verbosity: Verbosity where lt 0 is Fatal, eq 0 is Error, eq 1 is Info, gt 1 is Debug (default: 1)
weightCol: The name of the weight column (undefined)
================================================================================2021-07-17 22:16:29
step3: training model ...

+--------------------+------------------------+
| feature_name|feature_importance(gain)|
+--------------------+------------------------+
| worst_area| 974.9349449056517|
| worst_perimeter| 885.3691593843923|
|worst_concave_points| 255.67364284247745|
| mean_concave_points| 250.21955942230738|
| worst_texture| 151.07745621304454|
| area_error| 65.75557372416814|
| worst_smoothness| 62.29973236144293|
| mean_smoothness| 19.902610011957194|
| worst_radius| 16.8275272153341|
| mean_area| 12.41261211467938|
| mean_perimeter| 12.127510878875537|
| worst_concavity| 11.414242858900646|
| compactness_error| 10.996194651604892|
| mean_texture| 9.274276675339683|
| concavity_error| 8.009578698471008|
| symmetry_error| 7.93458393366217|
| radius_error| 7.357747321194173|
| worst_symmetry| 5.951699663755868|
|fractal_dimension...| 4.811246624133022|
|concave_points_error| 4.73140145466917|
| worst_compactness| 4.469820723182832|
| texture_error| 4.356178728700959|
| mean_compactness| 3.123736411467967|
| mean_symmetry| 1.9968633063354835|
| mean_concavity| 1.9701941942285224|
| smoothness_error| 1.673042485476758|
|worst_fractal_dim...| 1.3582115541525612|
|mean_fractal_dime...| 0.6050912755332459|
| perimeter_error| 0.3889888676278275|
| mean_radius| 5.684356116234315...|
+--------------------+------------------------+

================================================================================2021-07-17 22:16:30
step4: evaluating model ...

train_auc = 1.0
val_auc = 0.9890340267698758
================================================================================2021-07-17 22:16:31
step5: using model ...

+--------------------+-----+--------------------+--------------------+----------+
| features|label| rawPrediction| probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[9.0,12.0,60.34,2...| 1|[-10.570726382467...|[-9.5707263824679...| 1.0|
|[10.0,16.0,65.85,...| 1|[-10.120435089856...|[-9.1204350898567...| 1.0|
|[10.0,21.0,68.51,...| 1|[-8.8020346337692...|[-7.8020346337692...| 1.0|
|[11.0,14.0,73.53,...| 1|[-10.315758226759...|[-9.3157582267596...| 1.0|
|[11.0,15.0,73.38,...| 1|[-10.086077130817...|[-9.0860771308174...| 1.0|
|[11.0,16.0,74.72,...| 1|[-6.9649803118554...|[-5.9649803118554...| 1.0|
|[11.0,17.0,71.25,...| 1|[-10.694667171248...|[-9.6946671712481...| 1.0|
|[11.0,17.0,75.27,...| 1|[-9.0156792680894...|[-8.0156792680894...| 1.0|
|[11.0,18.0,75.17,...| 1|[-5.7513546284621...|[-4.7513546284621...| 1.0|
|[11.0,18.0,76.38,...| 1|[-4.3134421808792...|[-3.3134421808792...| 1.0|
|[12.0,15.0,82.57,...| 0|[2.49310942805160...|[3.49310942805160...| 0.0|
|[12.0,17.0,78.27,...| 1|[-10.516042459712...|[-9.5160424597122...| 1.0|
|[12.0,18.0,83.19,...| 1|[-9.4899850168431...|[-8.4899850168431...| 1.0|
|[12.0,22.0,78.75,...| 1|[-8.9917629958319...|[-7.9917629958319...| 1.0|
|[14.0,15.0,92.68,...| 1|[-7.2724968676775...|[-6.2724968676775...| 1.0|
|[14.0,15.0,95.77,...| 1|[-5.0143190624015...|[-4.0143190624015...| 1.0|
|[14.0,16.0,96.22,...| 1|[-5.3849620427583...|[-4.3849620427583...| 1.0|
|[14.0,19.0,97.83,...| 1|[-3.3292007261919...|[-2.3292007261919...| 1.0|
|[16.0,14.0,104.3,...| 1|[4.66077729134426...|[5.66077729134426...| 0.0|
|[19.0,24.0,122.0,...| 0|[10.1503565558166...|[11.1503565558166...| 0.0|
+--------------------+-----+--------------------+--------------------+----------+

1.0
================================================================================2021-07-17 22:16:31
step6: saving model ...

load model ...

That's a wrap. 😊
The code and dataset for the binary classification example in this article, together with example code and datasets for training multi-class and regression models with Spark-Scala and LightGBM, can be obtained by replying with the keyword spark+lightgbm to the WeChat official account 算法美食屋.