Java实现GBDT

Posted simple_wxl

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了用Java实现GBDT(梯度提升决策树)的相关知识,希望对你有一定的参考价值。

Data类

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Scanner;

/**
 * Loads a comma-separated text file into memory as rows of raw string fields.
 *
 * Each non-blank line becomes one row; the line is trimmed as a whole and then
 * split on ",". Individual fields are NOT trimmed, so callers that compare
 * field values must trim them themselves.
 */
public class Data {
	/** Path read by the no-argument constructor. */
	private static final String DEFAULT_DATA_PATH="D://javajavajava//dbdt//src//script//data//adult.data.csv";

	private ArrayList<ArrayList<String>> trainData=new ArrayList<ArrayList<String>>();

	/**
	 * @return the loaded rows (the internal list, not a copy); empty when the
	 *         file could not be found
	 */
	public ArrayList<ArrayList<String>> getTrainData() {
		return this.trainData;
	}

	/** Loads the data set from the default hard-coded location. */
	public Data() {
		this(DEFAULT_DATA_PATH);
	}

	/**
	 * Loads the data set from the given file.
	 *
	 * A missing file is reported on stderr and leaves the data set empty
	 * instead of throwing, matching the best-effort behavior of the
	 * no-argument constructor.
	 *
	 * @param dataPath path of the comma-separated file to read
	 */
	public Data(String dataPath) {
		// try-with-resources guarantees the underlying file handle is released.
		try (Scanner in = new Scanner(new File(dataPath))) {
			while (in.hasNextLine()) {
				String line=in.nextLine().trim();
				// Blank lines would otherwise yield a one-field row and break
				// column-indexed access later on.
				if(line.isEmpty())
				{
					continue;
				}
				String []strs=line.split(",");
				ArrayList<String> tmp=new ArrayList<>();
				for(String field:strs)
				{
					tmp.add(field);
				}
				this.trainData.add(tmp);
			}
		} catch (FileNotFoundException e) {
			System.err.println("Data file not found: "+dataPath);
			e.printStackTrace();
		}
	}

	public static void main(String[] args) {
		Data d =new Data();
		System.err.println(d.getTrainData().size());
	}

}

  Tree类

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.spi.TimeZoneNameProvider;

/**
 * A binary regression tree used as the weak learner of the gradient boosting
 * model.
 *
 * Inner nodes split on one attribute: numeric attributes send a row left when
 * its value is less than or equal to the split value, categorical attributes
 * send a row left when its (trimmed) value equals the split value. Leaves hold
 * a constant prediction.
 */
public class Tree {
	/** Shared generator used to sample candidate split values. */
	private static final Random RANDOM=new Random();

	// Children must start as null: creating them in the field initializer makes
	// every new Tree build two more Trees and overflows the stack.
	private Tree leftTree=null;
	private Tree rightTree=null;
	private double loss=-1;
	private int attributeSplit=0;
	private String attributeSplitType="";
	boolean isLeaf;
	double leafValue;
	private ArrayList<Integer> leafNodeSet=new ArrayList<>();
	
	/**
	 * Collects the distinct values of one column.
	 *
	 * Values are trimmed so that they compare equal to the trimmed values used
	 * when rows are routed through the tree.
	 *
	 * @param trainData all rows
	 * @param idx column index
	 * @return distinct trimmed values of the column, in no particular order
	 */
	public ArrayList<String> getAttributeSet(ArrayList<ArrayList<String>> trainData,int idx)
	{
		HashSet<String> distinct=new HashSet<>();
		for(ArrayList<String> row:trainData)
		{
			distinct.add(row.get(idx).trim());
		}
		return new ArrayList<>(distinct);
	}

	/**
	 * Compares two integer strings.
	 *
	 * @return true when str1 is less than or equal to str2
	 * @throws NumberFormatException when either string is not an integer
	 */
	public boolean myCmpLess(String str1,String str2)
	{
		return Integer.parseInt(str1.trim())<=Integer.parseInt(str2.trim());
	}

	/**
	 * Computes the impurity of a set of target values.
	 *
	 * @param values target values of the rows in one partition
	 * @return square root of the sum of squared deviations from the mean;
	 *         0 for an empty list
	 */
	public double computeLoss(ArrayList<Double> values)
	{
		double total=0;
		for(double v:values)
		{
			total+=v;
		}
		double mean=total/values.size();
		double squared=0;
		for(double v:values)
		{
			squared+=(v-mean)*(v-mean);
		}
		return Math.sqrt(squared);
	}

	/**
	 * Computes the leaf value for a set of rows:
	 * (K-1)/K * sum(r) / sum(|r| * (1-|r|)), where r is the residual of a row.
	 *
	 * @param K number of classes
	 * @param subIdx row indices that fall into the leaf
	 * @param target residuals, indexed by row index
	 * @return the leaf value, or 0 when the denominator is 0 (for example
	 *         when the leaf is empty)
	 */
	public double getPredictValue(int K, ArrayList<Integer> subIdx,ArrayList<Double> target) {
		double sum=0,sum1=0;
		for(int i=0;i<subIdx.size();i++)
		{
			double r=target.get(subIdx.get(i));
			sum+=r;
			sum1+=Math.abs(r)*(1-Math.abs(r));
		}
		if(sum1==0)
		{
			return 0;
		}
		// K-1.0 forces floating-point division; (K-1)/K in int arithmetic is 0.
		return (K-1.0)/K*sum/sum1;
	}

	/** @return the constant stored in the given node */
	public double getPredictValue(Tree root)
	{
		return root.leafValue;
	}

	/**
	 * Routes one instance from the given node down to a leaf.
	 *
	 * @param root node to start from
	 * @param instance attribute values of the instance
	 * @param isDigit per attribute, whether it is numeric
	 * @return the value of the leaf the instance ends in
	 */
	public double getPredictValue(Tree root,ArrayList<String> instance,Boolean isDigit[])
	{
		if(root.isLeaf)
		{
			return root.leafValue;
		}
		String value=instance.get(root.attributeSplit);
		Tree next=goesLeft(value, root.attributeSplitType, isDigit[root.attributeSplit])?root.leftTree:root.rightTree;
		return getPredictValue(next, instance, isDigit);
	}

	/**
	 * Builds a (sub)tree for the given rows.
	 *
	 * At every inner node all attributes are tried. The split with the
	 * smallest summed impurity of its two sides is kept; splits that leave one
	 * side empty are ignored. A leaf is created when the maximum depth is
	 * reached, when fewer than two rows remain, or when no usable split exists.
	 *
	 * @param leafNodes output: the row set of every created leaf is appended
	 * @param leafValues output: the value of every created leaf is appended,
	 *        in the same order as leafNodes
	 * @param K number of classes
	 * @param splitPoints maximum number of candidate thresholds tried per
	 *        numeric attribute
	 * @param isDigit per attribute, whether it is numeric
	 * @param subIdx indices of the rows that reach this node
	 * @param trainData all rows
	 * @param target residuals, indexed by row index
	 * @param maxDepth single-element array holding the maximum depth
	 * @param depth depth of the node being built
	 * @return the root of the built subtree
	 */
	public Tree constructTree(ArrayList<ArrayList<Integer>> leafNodes,ArrayList<Double> leafValues,int K,int splitPoints, Boolean isDigit[],ArrayList<Integer> subIdx,ArrayList<ArrayList<String>> trainData,ArrayList<Double> target,int maxDepth[],int depth)
	{
		if(depth>=maxDepth[0]||subIdx.size()<2||trainData.isEmpty())
		{
			return makeLeaf(leafNodes, leafValues, K, subIdx, target);
		}

		// Never look at columns that have no type information.
		int dim=Math.min(trainData.get(0).size(), isDigit.length);
		boolean found=false;
		double bestLoss=-1;
		int bestAttribute=0;
		String bestValue="";
		ArrayList<Integer> leftNodes=new ArrayList<>();
		ArrayList<Integer> rightNodes=new ArrayList<>();

		for(int i=0;i<dim;i++)
		{
			ArrayList<String> candidates=candidateSplitValues(trainData, i, isDigit[i], splitPoints);
			for(int j=0;j<candidates.size();j++)
			{
				// Fresh partitions per candidate; reusing them would
				// accumulate the rows of every earlier candidate.
				ArrayList<Integer> leftTreeIdx=new ArrayList<>();
				ArrayList<Integer> rightTreeIdx=new ArrayList<>();
				ArrayList<Double> leftTarget=new ArrayList<>();
				ArrayList<Double> rightTarget=new ArrayList<>();
				for(int k=0;k<subIdx.size();k++)
				{
					int row=subIdx.get(k);
					if(goesLeft(trainData.get(row).get(i), candidates.get(j), isDigit[i]))
					{
						leftTreeIdx.add(row);
						leftTarget.add(target.get(row));
					}
					else
					{
						rightTreeIdx.add(row);
						rightTarget.add(target.get(row));
					}
				}
				if(leftTreeIdx.isEmpty()||rightTreeIdx.isEmpty())
				{
					continue;
				}
				double lossTmp=computeLoss(leftTarget)+computeLoss(rightTarget);
				if(!found||lossTmp<bestLoss)
				{
					found=true;
					bestLoss=lossTmp;
					bestAttribute=i;
					bestValue=candidates.get(j);
					leftNodes=leftTreeIdx;
					rightNodes=rightTreeIdx;
				}
			}
		}

		if(!found)
		{
			return makeLeaf(leafNodes, leafValues, K, subIdx, target);
		}

		Tree tmpTree=new Tree();
		tmpTree.attributeSplit=bestAttribute;
		tmpTree.attributeSplitType=bestValue;
		tmpTree.loss=bestLoss;
		tmpTree.isLeaf=false;
		tmpTree.leftTree=constructTree(leafNodes,leafValues,K,splitPoints, isDigit, leftNodes, trainData, target, maxDepth, depth+1);
		tmpTree.rightTree=constructTree(leafNodes,leafValues,K,splitPoints, isDigit, rightNodes, trainData, target, maxDepth, depth+1);
		return tmpTree;
	}

	/** Decides whether a raw attribute value belongs to the left branch of a split. */
	private boolean goesLeft(String value,String splitValue,boolean numeric)
	{
		if(numeric)
		{
			return myCmpLess(value, splitValue);
		}
		return value.trim().equals(splitValue);
	}

	/**
	 * Returns the split values to try for one attribute: every distinct value
	 * for categorical attributes, and at most splitPoints distinct values
	 * sampled without replacement for numeric ones.
	 */
	private ArrayList<String> candidateSplitValues(ArrayList<ArrayList<String>> trainData,int idx,boolean numeric,int splitPoints)
	{
		ArrayList<String> distinct=getAttributeSet(trainData, idx);
		if(!numeric||splitPoints<=0||distinct.size()<=splitPoints)
		{
			return distinct;
		}
		ArrayList<String> sampled=new ArrayList<>();
		while(sampled.size()<splitPoints)
		{
			sampled.add(distinct.remove(RANDOM.nextInt(distinct.size())));
		}
		return sampled;
	}

	/** Creates a leaf for the given rows and records it in the output lists. */
	private Tree makeLeaf(ArrayList<ArrayList<Integer>> leafNodes,ArrayList<Double> leafValues,int K,ArrayList<Integer> subIdx,ArrayList<Double> target)
	{
		Tree leaf=new Tree();
		leaf.isLeaf=true;
		leaf.leafValue=getPredictValue(K, subIdx, target);
		leaf.leafNodeSet.addAll(subIdx);
		leafNodes.add(subIdx);
		leafValues.add(leaf.leafValue);
		return leaf;
	}
	
	public static void main(String[] args) {
		Tree aTree=new Tree();
	}

}

  

GBDT类

import java.rmi.server.SkeletonNotFoundException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;


public class GBDT {
	
	private ArrayList<ArrayList<String>> datas=new ArrayList<ArrayList<String>>();
	private ArrayList<String> labelSets=new ArrayList<>();
	private ArrayList<ArrayList<Double>> F=new ArrayList<ArrayList<Double>>();
	private ArrayList<ArrayList<Double>> residual=new ArrayList<ArrayList<Double>>();
	private ArrayList<ArrayList<String>> trainData=new ArrayList<ArrayList<String>>();
	private ArrayList<Integer> labelTrainData=new ArrayList<Integer>();
	private int K;
	private Boolean isDigit[];
	private int dim;
	private int n;
	private double learningRate;
	
	private ArrayList<ArrayList<Tree>> trees=new ArrayList<ArrayList<Tree>>(); //存放所有的树
	
	private int max_iter;
	private double sampleRate;
	private int maxDepth;
	private int splitPoints;

	public void computeResidual(ArrayList<Integer> subId)
	{
		for(int i=0;i<subId.size();i++)
		{
			int idx=subId.get(i);
			int y=0;
			if(this.labelTrainData.get(idx)==-1) y=0;
			else y=1;
			double sum=Math.exp(this.F.get(idx).get(0))+Math.exp(this.F.get(idx).get(1));
			double p1=Math.exp(this.F.get(idx).get(0))/sum,p2=Math.exp(this.F.get(idx).get(1))/sum;
			this.residual.get(idx).set(0, y-p1);
			this.residual.get(idx).set(1, y-p2);
		}
	}
	public ArrayList<Integer> myrandom(int maxNum,int num)
	{
		ArrayList<Integer> ans=new ArrayList<>();
		Set<Integer> mySet=new HashSet<>();
		while(mySet.size()<num)
		{
			Random r=new Random();
			int tmp=r.nextInt(maxNum);
			mySet.add(tmp);
		}
		Iterator<Integer> it=mySet.iterator();
		while(it.hasNext())
		{
			ans.add(it.next());
		}
		return ans;
	}
	
	public GBDT()
	{
		this.max_iter=50;
		this.sampleRate=0.8;
		this.K=2;//2分类问题
		this.maxDepth=6;
		this.splitPoints=3;
		this.learningRate=0.01;
		getData();
	}
	
	public void train()
	{
		for(int i=0;i<max_iter;i++)
		{
			ArrayList<Integer> subSet=new ArrayList<>();
			int numSubset=(int)(this.n*this.sampleRate);
			subSet=myrandom(this.n,numSubset);
			computeResidual(subSet);
			ArrayList<Double> target=new ArrayList<>();
			ArrayList<Tree> tmpTree=new ArrayList<>();
			int maxdepths[]={this.maxDepth};
			for(int j=0;j<this.K;j++)
			{
				target.clear();
				for(int k=0;k<subSet.size();k++)
				{
					target.add(residual.get(subSet.get(k)).get(j));
				}
				ArrayList<ArrayList<Integer>> leafNodes=new ArrayList<ArrayList<Integer>>();
				ArrayList<Double> leafValues=new ArrayList<>();
				Tree treeSub=new Tree();
				Tree iterTree=treeSub.constructTree(leafNodes,leafValues,K,splitPoints, isDigit, subSet, trainData, target,maxdepths,0);
				tmpTree.add(iterTree);
				updateFvalue(isDigit, subSet,leafNodes,leafValues,j,iterTree);
			}
			
			trees.add(tmpTree);
		}
	}
	
	public void updateFvalue(Boolean isDigit[], ArrayList<Integer> subIdx,ArrayList<ArrayList<Integer>> leafNodes,ArrayList<Double> leafValues,int label,Tree root)
	{
		ArrayList<Integer> remainIdx=new ArrayList<>();
		int arr[]=new int[this.n];
		for(int i=0;i<this.n;i++)
			arr[i]=i;
		for(int i=0;i<subIdx.size();i++)
		{
			arr[subIdx.get(i)]=-1;
		}
		//求出不是用来训练树的余下集合
		for(int i=0;i<this.n;i++)
		{
			if(arr[i]!=-1)
				remainIdx.add(i);
		}
		for(int i=0;i<leafNodes.size();i++)
		{
			for(int j=0;j<leafNodes.get(i).size();j++)
			{
				this.F.get(leafNodes.get(i).get(j)).set(label, this.F.get(leafNodes.get(i).get(j)).get(label)+this.learningRate*root.getPredictValue(root));
			}
		}
		for(int i=0;i<remainIdx.size();i++)
		{
			double leafV=root.getPredictValue(root,this.trainData.get(remainIdx.get(i)),isDigit);
			this.F.get(remainIdx.get(i)).set(label, this.F.get(remainIdx.get(i)).get(label)+this.learningRate*leafV);
		}
		
		
	}
	
	public boolean checkDigit(String str) {
		for(int i=0;i<str.length();i++)
		{
			if(!(str.charAt(i)>=‘0‘&&str.charAt(i)<=‘9‘))
			{
				return false;
			}
		}
		return true;
	}
	
	public void getData() {
		Data d =new Data();
		this.datas=d.getTrainData();
		this.dim=this.datas.get(0).size()-1;
		this.isDigit=new Boolean[this.dim];
		//遍历所有样本,去掉中间含有不是正常的数据
		for(int i=0;i<this.datas.get(0).size()-1;i++)
			labelSets.add(this.datas.get(0).get(i));
		//保证数据的第一行是正确的,来判断,特征哪些纬度是数字,哪些纬度是字符串
		for(int i=0;i<this.dim;i++)
		{
			if(checkDigit(this.datas.get(0).get(i)))
				this.isDigit[i]=true;
			else this.isDigit[i]=false;
		}
		//如果字符串==?说明是异常数据,这里做数据的清理
		for(int i=1;i<this.datas.size();i++)
		{
			ArrayList<String> tmp=new ArrayList<>();
			boolean flag=true;
			for(int j=0;j<this.dim;j++)
			{
				if(datas.get(i).get(j).trim().equals("?"))
				{
					flag=false;
					break;
				}
			}
			if(!flag) continue;
			if(datas.get(i).get(this.dim).trim().equals("?")) continue;
			trainData.add(tmp);
			if(datas.get(i).get(this.dim).trim().equals("<=50K")) 
				labelTrainData.add(-1);
			else
				labelTrainData.add(1);
			
		}
		this.n=this.labelTrainData.size();
		
		for(int i=0;i<this.datas.get(0).size()-1;i++)
			labelSets.add(this.datas.get(0).get(i));
		
		//初始化F矩阵为全0,F矩阵是n*2,是2分类问题,如果要多分类,改下这里就可以了
		for(int i=0;i<this.n;i++)
		{
			ArrayList<Double> arrTmp=new ArrayList<Double>();
			for(int j=0;j<2;j++)
			{
				arrTmp.add(0.0);
			}
			this.F.add(arrTmp);
			this.residual.add(arrTmp);
		}
		
							
	}
	
	public static void main(String[] args) {
		GBDT dGbdt=new GBDT();
		dGbdt.getData();
		System.err.println(dGbdt.n);
		
	}
}

  

 

以上是关于Java实现GBDT的主要内容。如果未能解决你的问题,请参考以下文章:

# Java 常用代码片段

java 代码片段

LockSupport.java 中的 FIFO 互斥代码片段

CTR预估-GBDT与LR实现

一文速学-GBDT模型算法原理以及实现+Python项目实战

梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)