Reloaded Spark model does not seem to work
Posted: 2016-08-13 22:46:56
I am training a model on data from a CSV file and saving it. This first step works fine. After saving the model, I try to load it and use it on new data, but it does not work.
What is the problem?
Training Java file:
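import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

SparkConf sconf = new SparkConf().setMaster("local[*]").setAppName("Test").set("spark.sql.warehouse.dir", "D:/Temp/wh");
SparkSession spark = SparkSession.builder().appName("Java Spark").config(sconf).getOrCreate();

// Each CSV line is "productName,className"
JavaRDD<Cobj> cRDD = spark.read().textFile("file:///C:/Temp/classifications1.csv").javaRDD()
    .map(new Function<String, Cobj>() {
        @Override
        public Cobj call(String line) throws Exception {
            String[] parts = line.split(",");
            Cobj c = new Cobj();
            c.setClassName(parts[1].trim());
            c.setProductName(parts[0].trim());
            return c;
        }
    });
Dataset<Row> mainDataset = spark.createDataFrame(cRDD, Cobj.class);

// StringIndexer: className -> numeric label; rows with unseen classes are skipped
StringIndexer classIndexer = new StringIndexer()
    .setHandleInvalid("skip")
    .setInputCol("className")
    .setOutputCol("label");
StringIndexerModel classIndexerModel = classIndexer.fit(mainDataset);

// Tokenizer: productName -> lowercase words
Tokenizer tokenizer = new Tokenizer()
    .setInputCol("productName")
    .setOutputCol("words");

// HashingTF: words -> term-frequency feature vector
HashingTF hashingTF = new HashingTF()
    .setInputCol(tokenizer.getOutputCol())
    .setOutputCol("features");

DecisionTreeClassifier decisionClassifier = new DecisionTreeClassifier()
    .setLabelCol("label")
    .setFeaturesCol("features");

Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[] {classIndexer, tokenizer, hashingTF, decisionClassifier});

Dataset<Row>[] splits = mainDataset.randomSplit(new double[] {0.8, 0.2});
Dataset<Row> train = splits[0];
Dataset<Row> test = splits[1];

PipelineModel pipelineModel = pipeline.fit(train);
Dataset<Row> result = pipelineModel.transform(test);
pipelineModel.write().overwrite().save(savePath + "DecisionTreeClassificationModel");

// Note: classIndexerModel was fitted on all of mainDataset, separately from the indexer inside
// the pipeline (fitted on train), so its labels() order can differ from the pipeline's labels.
// ((StringIndexerModel) pipelineModel.stages()[0]).labels() would keep them consistent.
IndexToString labelConverter = new IndexToString()
    .setInputCol("prediction")
    .setOutputCol("PredictedClassName")
    .setLabels(classIndexerModel.labels());
result = labelConverter.transform(result);
result.show(num, false);

Dataset<Row> predictionAndLabels = result.select("prediction", "label");
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
    .setMetricName("accuracy");
System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));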
Output:
+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
|className |productName |label|words |features |rawPrediction |probability |prediction|PredictedClassName |
+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
|Apple iPhone 6S 16GB |Apple IPHONE 6S 16GB SGAY Telefon |2.0 |[apple, iphone, 6s, 16gb, sgay, telefon] |(262144,[27536,56559,169565,200223,210029,242621],[1.0,1.0,1.0,1.0,1.0,1.0]) |[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0 |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S 16GB |Apple iPhone 6S 16 GB Space Gray MKQJ2TU/A |2.0 |[apple, iphone, 6s, 16, gb, space, gray, mkqj2tu/a] |(262144,[10879,56559,95900,139131,175329,175778,200223,210029],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0 |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S 16GB |iPhone 6s 16GB |2.0 |[iphone, 6s, 16gb] |(262144,[27536,56559,210029],[1.0,1.0,1.0]) |[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0 |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S Plus 128GB|Apple IPHONE 6S PLUS 128GB SG Telefon |4.0 |[apple, iphone, 6s, plus, 128gb, sg, telefon] |(262144,[56559,99916,137263,175839,200223,210029,242621],[1.0,1.0,1.0,1.0,1.0,1.0,1.0]) |[0.0,0.0,0.0,0.0,2.0]|[0.0,0.0,0.0,0.0,1.0]|4.0 |Apple iPhone 6S Plus 128GB|
|Apple iPhone 6S Plus 16GB |Iphone 6S Plus 16GB SpaceGray - Apple Türkiye|1.0 |[iphone, 6s, plus, 16gb, spacegray, -, apple, türkiye]|(262144,[27536,45531,46750,56559,59104,99916,200223,210029],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]) |[0.0,5.0,0.0,0.0,0.0]|[0.0,1.0,0.0,0.0,0.0]|1.0 |Apple iPhone 6S Plus 16GB |
+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
Accuracy = 1.0
Loading Java file:
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

SparkConf sconf = new SparkConf().setMaster("local[*]").setAppName("Test").set("spark.sql.warehouse.dir", "D:/Temp/wh");
SparkSession spark = SparkSession.builder().appName("Java Spark").config(sconf).getOrCreate();
JavaRDD<Cobj> cRDD = spark.read().textFile("file:///C:/Temp/classificationsTest.csv").javaRDD()
    .map(new Function<String, Cobj>() {
        @Override
        public Cobj call(String line) throws Exception {
            String[] parts = line.split(",");
            Cobj c = new Cobj();
            // The class of new data is unknown, so "?" is used as a placeholder
            c.setClassName("?");
            c.setProductName(parts[0].trim());
            return c;
        }
    });
Dataset<Row> mainDataset = spark.createDataFrame(cRDD, Cobj.class);
mainDataset.show(100, false);

PipelineModel pipelineModel = PipelineModel.load(savePath + "DecisionTreeClassificationModel");
Dataset<Row> result = pipelineModel.transform(mainDataset);
result.show(100, false);
Output:
+---------+-----------+-----+-----+--------+-------------+-----------+----------+
|className|productName|label|words|features|rawPrediction|probability|prediction|
+---------+-----------+-----+-----+--------+-------------+-----------+----------+
+---------+-----------+-----+-----+--------+-------------+-----------+----------+
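A likely explanation for the empty table: the StringIndexer is one of the pipeline's stages, so PipelineModel.transform also runs it on the new data. It was configured with setHandleInvalid("skip"), which drops rows whose className was not seen during fitting, and every row in the new file has className "?". The sketch below is a plain-Java model of that behavior, not Spark code; the class and method names are hypothetical, and breaking frequency ties alphabetically is an assumption of the sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Plain-Java model (NOT Spark code) of a fitted StringIndexer with handleInvalid = "skip".
public class SkipIndexerDemo {

    // "fit": index labels by descending frequency; ties broken alphabetically (assumption of this sketch)
    static Map<String, Double> fit(List<String> trainingLabels) {
        Map<String, Integer> counts = new HashMap<>();
        for (String label : trainingLabels) {
            counts.merge(label, 1, Integer::sum);
        }
        List<String> ordered = new ArrayList<>(counts.keySet());
        ordered.sort(Comparator.comparing((String l) -> -counts.get(l))
                .thenComparing(Comparator.<String>naturalOrder()));
        Map<String, Double> index = new LinkedHashMap<>();
        for (int i = 0; i < ordered.size(); i++) {
            index.put(ordered.get(i), (double) i);
        }
        return index;
    }

    // "transform" with skip: keep only rows whose label was seen during fit
    static List<String> transformSkip(Map<String, Double> index, List<String> labels) {
        List<String> kept = new ArrayList<>();
        for (String label : labels) {
            if (index.containsKey(label)) {
                kept.add(label);
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        Map<String, Double> index = fit(Arrays.asList(
                "Apple iPhone 6S 16GB", "Apple iPhone 6S 16GB", "Apple iPhone 6S Plus 128GB"));
        System.out.println(index); // {Apple iPhone 6S 16GB=0.0, Apple iPhone 6S Plus 128GB=1.0}
        // Every row of the new file has className "?", which was never seen during fit
        List<String> newRows = Arrays.asList("?", "?", "?");
        System.out.println(transformSkip(index, newRows).size()); // prints 0
    }
}
```

With "skip", an unseen label is not an error, so nothing fails; the rows simply disappear before the tokenizer and classifier stages run, which matches an output table with headers but no rows.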
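Answer 1: I removed the StringIndexer from the pipeline and saved it separately as "StringIndexer". In the second file, after loading the pipeline, I load the StringIndexer and use it to convert the predictions back into class labels.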
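A minimal sketch of the approach described in this answer, assuming the question's column names (className, productName) and its savePath prefix convention. The class name ReloadFixSketch and the train/predict split are hypothetical; the Spark calls used (StringIndexerModel.write().overwrite().save, StringIndexerModel.load, PipelineModel.load, IndexToString.setLabels) are standard Spark ML APIs. The indexer is fitted and applied outside the pipeline, so the saved pipeline only reads productName, and the saved indexer's labels() drive IndexToString at prediction time.

```java
import java.io.IOException;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Hypothetical helper class sketching the answer's approach.
public class ReloadFixSketch {

    // Training side: fit the label indexer on its own and save it next to the pipeline.
    public static void train(Dataset<Row> data, String savePath) throws IOException {
        StringIndexerModel labelIndexer = new StringIndexer()
                .setInputCol("className")
                .setOutputCol("label")
                .fit(data);
        // Adds the "label" column (with class-count metadata the classifier needs)
        Dataset<Row> indexed = labelIndexer.transform(data);

        Tokenizer tokenizer = new Tokenizer().setInputCol("productName").setOutputCol("words");
        HashingTF hashingTF = new HashingTF().setInputCol("words").setOutputCol("features");
        DecisionTreeClassifier classifier = new DecisionTreeClassifier()
                .setLabelCol("label")
                .setFeaturesCol("features");

        // No StringIndexer here, so the saved pipeline never looks at className
        Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[] {tokenizer, hashingTF, classifier});
        PipelineModel model = pipeline.fit(indexed);

        model.write().overwrite().save(savePath + "DecisionTreeClassificationModel");
        labelIndexer.write().overwrite().save(savePath + "StringIndexer");
    }

    // Prediction side: load both models, classify, and map indices back to class names.
    public static Dataset<Row> predict(Dataset<Row> newData, String savePath) throws IOException {
        PipelineModel model = PipelineModel.load(savePath + "DecisionTreeClassificationModel");
        StringIndexerModel labelIndexer = StringIndexerModel.load(savePath + "StringIndexer");
        IndexToString labelConverter = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("PredictedClassName")
                .setLabels(labelIndexer.labels());
        return labelConverter.transform(model.transform(newData));
    }
}
```

Because the same StringIndexerModel both produces the training labels and supplies the label order to IndexToString, the predicted index and the class name it is mapped to come from one consistent mapping.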