Apache Spark: Decision Tree Operations with MLlib (Java)
Posted by 你是小KS
Current version: Spark 2.4.6
1. Preface

This article documents my study of Spark ML, specifically decision trees and random forests. It is based on Chapter 4 of Advanced Analytics with Spark, "Predicting Forest Cover with Decision Trees". The original code is written in Scala; here it is reimplemented in Java.

Data preparation: data download address

The covtype dataset has 581,012 rows; each row holds 54 feature columns (10 numeric columns, 4 one-hot wilderness-area indicators, and 40 one-hot soil-type indicators) followed by a cover-type label from 1 to 7. Remove the last two rows of the file and set them aside as prediction data:
2384,170,15,60,5,90,230,245,143,864,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3
2383,165,13,60,4,67,231,244,141,875,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3
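To sanity-check the download before training, a minimal sketch (assumes a JavaSparkContext named jsc, as created in the main code below; the local path matches the one used there):

JavaRDD<String> raw = jsc.textFile("C:\\Users\\admin\\Desktop\\mldata\\covtype.data");
// The full covtype dataset has 581,012 rows, so 581,010 should remain once
// the two prediction rows above have been removed
System.out.println("rows: " + raw.count());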
2. Main Code
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;
/**
 *
 * @author hy
 * @createTime 2021-09-11 08:26:26
 * @description Decision-tree models from "Advanced Analytics with Spark":
 *              1. decision tree
 *              2. random forest
 *
 */
public class DecisionTreeTest {

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local").setAppName("test");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        jsc.setLogLevel("WARN");
        // Train, evaluate and predict with a decision tree
        decisionTreeTest(jsc);
        jsc.close();
    }
    /**
     *
     * @author hy
     * @createTime 2021-09-11 12:53:01
     * @description Build LabeledPoints that keep all 54 raw columns as numeric features
     * @param rawData
     * @return
     *
     */
    private static JavaRDD<LabeledPoint> createLabeledPointDataUsingBefore(JavaRDD<String> rawData) {
        JavaRDD<LabeledPoint> data = rawData.map(new Function<String, LabeledPoint>() {
            @Override
            public LabeledPoint call(String v1) throws Exception {
                String[] strings = v1.split(",");
                double[] values = new double[strings.length];
                for (int i = 0; i < strings.length; i++) {
                    values[i] = Double.valueOf(strings[i]);
                }
                double[] noLastNumValues = Arrays.copyOf(values, values.length - 1);
                Vector featureVector = Vectors.dense(noLastNumValues);
                // The last value is the label, a category in the range 1..7
                double label = values[values.length - 1] - 1;
                // DecisionTree requires labels starting from 0, hence the -1
                LabeledPoint labeledPoint = new LabeledPoint(label, featureVector);
                return labeledPoint;
            }
        });
        return data;
    }
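    // Note: this variant leaves the two one-hot blocks (wilderness area and soil
    // type) as 44 separate numeric columns. The createLabeledPointDataUsingType
    // variant below condenses each block into a single categorical feature, so
    // the tree can branch on the category directly instead of on 44 0/1 columns.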
    // Return the index of findValue in array, or -1 if it is not present
    private static int indexOfArray(double[] array, double findValue) {
        int index = -1;
        for (int i = 0; i < array.length; i++) {
            if (findValue == array[i]) {
                index = i;
                break;
            }
        }
        return index;
    }
    /**
     *
     * @author hy
     * @createTime 2021-09-11 13:45:03
     * @description Build LabeledPoints that encode the one-hot blocks as categorical features
     * @param rawData
     * @return
     *
     */
    private static JavaRDD<LabeledPoint> createLabeledPointDataUsingType(JavaRDD<String> rawData) {
        JavaRDD<LabeledPoint> data = rawData.map(new Function<String, LabeledPoint>() {
            @Override
            public LabeledPoint call(String v1) throws Exception {
                String[] strings = v1.split(",");
                double[] values = new double[strings.length];
                for (int i = 0; i < strings.length; i++) {
                    values[i] = Double.valueOf(strings[i]);
                }
                Vector featureVector = createVectorByLine(v1);
                // The last value is the label, a category in the range 1..7
                double label = values[values.length - 1] - 1;
                // DecisionTree requires labels starting from 0, hence the -1
                LabeledPoint labeledPoint = new LabeledPoint(label, featureVector);
                return labeledPoint;
            }
        });
        return data;
    }
    /**
     *
     * @author hy
     * @createTime 2021-09-12 08:12:23
     * @description Convert one line of raw data into a feature vector, turning
     *              the two one-hot blocks into single categorical features
     * @param line
     * @return
     *
     */
    private static Vector createVectorByLine(String line) {
        String[] strings = line.split(",");
        double[] values = new double[strings.length];
        for (int i = 0; i < strings.length; i++) {
            values[i] = Double.valueOf(strings[i]);
        }
        // Categorical feature: wilderness area (columns 10-13, one-hot encoded)
        double[] wildernessValues = Arrays.copyOfRange(values, 10, 14);
        double wilderness = indexOfArray(wildernessValues, 1.0);
        // Categorical feature: soil type (columns 14-53, one-hot encoded)
        double[] soilValues = Arrays.copyOfRange(values, 14, 54);
        double soil = indexOfArray(soilValues, 1.0);
        double[] copyOfRange = Arrays.copyOfRange(values, 0, 10);
        double[] copyOf = Arrays.copyOf(copyOfRange, copyOfRange.length + 2);
        copyOf[copyOf.length - 2] = wilderness;
        copyOf[copyOf.length - 1] = soil;
        Vector featureVector = Vectors.dense(copyOf);
        return featureVector;
    }
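    // Worked example, using the first held-out row above: the wilderness columns
    // 10-13 are [0,0,1,0], so wilderness = 2; the soil columns 14-53 have their 1
    // at position 1, so soil = 1. The 54 raw feature columns therefore condense
    // to the 12-element vector [2384,170,15,60,5,90,230,245,143,864,2,1].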
    private static void decisionTreeTest(JavaSparkContext jsc) {
        JavaRDD<String> rawData = jsc.textFile("C:\\Users\\admin\\Desktop\\mldata\\covtype.data");
        JavaRDD<LabeledPoint> data = createLabeledPointDataUsingType(rawData);
        // Split the data: 80% training, 10% cross-validation, 10% test
        JavaRDD<LabeledPoint>[] randomSplit = data.randomSplit(new double[] { 0.8, 0.1, 0.1 });
        JavaRDD<LabeledPoint> trainData = randomSplit[0];
        JavaRDD<LabeledPoint> cvData = randomSplit[1];
        JavaRDD<LabeledPoint> testData = randomSplit[2];
        trainData.cache();
        cvData.cache();
        testData.cache();
        // Build the decision-tree model (use trainClassifier for a categorical
        // target, trainRegressor for a numeric one)
        HashMap<Integer, Integer> hashMap = new HashMap<Integer, Integer>();
        // 7 = number of target classes, 4 = maximum depth, 100 = number of bins;
        // the map holds categorical-feature info (empty here, so every feature is
        // treated as numeric). "gini" is one of the two impurity measures, the
        // other being entropy.
        DecisionTreeModel model = DecisionTree.trainClassifier(trainData, 7, hashMap, "gini", 4, 100);
        // Random-forest alternative (too slow on my machine)
        // RandomForestModel model = createDecisionRandomForest(trainData);
        MulticlassMetrics metrics = getMetrics(model, cvData);
        Matrix confusionMatrix = metrics.confusionMatrix();
        System.out.println(confusionMatrix);
        // The two values below both give the overall accuracy and are identical
        System.out.println("Accuracy: " + metrics.accuracy());
        System.out.println("Precision: " + metrics.precision());
        // Precision and recall of each class against the others
        List<Tuple2<Double, Double>> list = new ArrayList<>();
        for (int i = 0; i < 7; i++) {
            Tuple2<Double, Double> tuple2 = new Tuple2<Double, Double>(metrics.precision(i), metrics.recall(i));
            list.add(tuple2);
        }
        System.out.println("Per-class precision and recall:");
        list.forEach(x -> System.out.println(x));
        // Baseline: the accuracy a classifier guessing at random according to the
        // class distribution would reach; the trained model should beat this
        Double[] trainProbabilities = classProbabilities(trainData);
        Double[] cvProbabilities = classProbabilities(cvData);
        double sum = 0.0;
        for (int i = 0; i < cvProbabilities.length; i++) {
            sum += cvProbabilities[i] * trainProbabilities[i];
        }
        System.out.println("Random-guess baseline accuracy: " + sum);
        // Retrain with tuned hyperparameters (impurity, depth and bin count are
        // normally chosen by looping over candidate values and keeping the most
        // accurate model; see the evaluateHyperParams sketch below), then
        // evaluate against testData
        DecisionTreeModel newModel = DecisionTree.trainClassifier(trainData.union(cvData), 7, hashMap, "entropy", 20, 300);
        MulticlassMetrics newMetrics = getMetrics(newModel, testData);
        double accuracy = newMetrics.accuracy();
        System.out.println("Accuracy of the tuned tree on testData: " + accuracy);
        // Random-forest prediction input (already in condensed 12-column form)
        /* String input = "2709,125,28,67,23,3224,253,207,61,4094,0,29"; */
        // Decision-tree prediction
        String[] lines = { "2384,170,15,60,5,90,230,245,143,864,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3",
                "2383,165,13,60,4,67,231,244,141,875,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3" };
        for (int i = 0; i < lines.length; i++) {
            String line = lines[i];
            Vector vector = createVectorByLine(line);
            double predict = newModel.predict(vector);
            // The predicted label is 2.0; adding back the 1 subtracted during
            // labeling gives class 3, matching the actual label of these rows
            System.out.println("Prediction: " + predict);
        }
    }
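    /**
     * Reconstructed helper (its definition was missing from the original
     * listing): pairs each point's prediction with its true label and wraps
     * them in MulticlassMetrics, following the standard MLlib evaluation
     * pattern. A minimal sketch; the original may have differed in detail.
     */
    private static MulticlassMetrics getMetrics(DecisionTreeModel model, JavaRDD<LabeledPoint> data) {
        JavaRDD<Tuple2<Object, Object>> predictionsAndLabels = data.map(
                p -> new Tuple2<Object, Object>(model.predict(p.features()), p.label()));
        return new MulticlassMetrics(predictionsAndLabels.rdd());
    }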
    /**
     *
     * @author hy
     * @createTime 2021-09-11 13:50:08
     * @description Create a random forest
     *
     */
    private static RandomForestModel createDecisionRandomForest(JavaRDD<LabeledPoint> trainData) {
        Map<Integer, Integer> hashMap = new HashMap<>();
        // Feature 10 (wilderness area) has 4 categories, feature 11 (soil type) has 40
        hashMap.put(10, 4);
        hashMap.put(11, 40);
        // 20 = number of trees to build; this did not finish even after 20 minutes
        RandomForestModel randomForestModel = RandomForest.trainClassifier(trainData, 7, hashMap, 20, "auto", "entropy", 30, 300, 10);
        return randomForestModel;
    }
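    /**
     * A minimal grid-search sketch added for illustration (the method name and
     * the candidate values are my own assumptions, not part of the original
     * code): train one tree per impurity/depth/bins combination, evaluate it on
     * the cross-validation set, and print the accuracy so the best combination
     * can be picked for the final model.
     */
    private static void evaluateHyperParams(JavaRDD<LabeledPoint> trainData, JavaRDD<LabeledPoint> cvData) {
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        for (String impurity : new String[] { "gini", "entropy" }) {
            for (int depth : new int[] { 1, 20 }) {
                for (int bins : new int[] { 10, 300 }) {
                    DecisionTreeModel model = DecisionTree.trainClassifier(
                            trainData, 7, categoricalFeaturesInfo, impurity, depth, bins);
                    MulticlassMetrics metrics = getMetrics(model, cvData);
                    System.out.println(impurity + ", depth=" + depth + ", bins=" + bins
                            + " -> accuracy=" + metrics.accuracy());
                }
            }
        }
    }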
    // Compute the proportion of each class label in the data
    private static Double[] classProbabilities(JavaRDD<LabeledPoint> data) {
        Map<Double, Long> countByValue = data.map(x -> x.label()).countByValue();
        List<Tuple2<Double, Long>> counts = new ArrayList<>();
        Long sum = 0L;
        for (Entry<Double, Long> entry : countByValue.entrySet()) {
            counts.add(new Tuple2<Double, Long>(entry.getKey(), entry.getValue()));
            sum += entry.getValue();
        }
        // Sort by label so the proportions line up across datasets
        counts.sort(new Comparator<Tuple2<Double, Long>>() {
            @Override
            public int compare(Tuple2<Double, Long> o1, Tuple2<Double, Long> o2) {
                return Double.compare(o1._1, o2._1);
            }
        });
        Double[] returnValues = new Double[counts.size()];
        for (int i = 0; i < returnValues.length; i++) {
            returnValues[i] = counts.get(i)._2 / (sum * 1.0);
        }
        return returnValues;
    }
}