决策树

Posted 这个签名很没水平

tags:

篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了决策树相关的知识，希望对你有一定的参考价值。

package TreeStructure;

import java.util.ArrayList;
import java.util.List;

public class testClass {

    /**
     * Demo driver: builds a decision tree from a small hard-coded training set
     * and classifies one partial query against it.
     *
     * <p>Each training row in {@code exercise} is laid out as
     * {weather, thin, cloth, target}. The columns are permuted by {@code index}
     * before training so the tree splits on the attributes in that order. The
     * target column must stay last, because {@code DecisionTree.createDT}
     * treats the single remaining column as the label.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        double[][] exercise = {{1,1,0,0},{1,3,1,1},{3,2,0,0},{3,2,1,10},{3,2,1,10},{3,2,1,10},{2,2,1,1},{3,2,1,9},{2,3,0,1},{2,1,0,0},{3,2,0,1},{2,1,0,1},{1,1,0,1}};
        String[] Attribute = {"weather", "thin", "cloth", "target"};
        int[] index = {1, 0, 2, 3};

        // Permute the data columns: new column j is old column index[j].
        double[][] exerciseData = new double[exercise.length][];
        for (int i = 0; i < exerciseData.length; i++) {
            exerciseData[i] = new double[exercise[i].length];
            for (int j = 0; j < exerciseData[i].length; j++) {
                exerciseData[i][j] = exercise[i][index[j]];
            }
        }

        for (int i = 0; i < exerciseData.length; i++) {
            for (int j = 0; j < exerciseData[i].length; j++) {
                System.out.print("  " + exerciseData[i][j]);
            }
            System.out.println();
        }

        DecisionTree dt = new DecisionTree();
        List<ArrayList<String>> data = new ArrayList<ArrayList<String>>();
        for (int i = 0; i < exerciseData.length; i++) {
            ArrayList<String> row = new ArrayList<String>();
            for (int j = 0; j < exerciseData[i].length; j++) {
                row.add(String.valueOf(exerciseData[i][j]));
            }
            data.add(row);
        }

        // The attribute names must be permuted exactly like the data columns.
        // Otherwise every tree level is labelled with the wrong attribute.
        List<String> attribute = new ArrayList<String>();
        for (int j = 0; j < index.length; j++) {
            attribute.add(Attribute[index[j]]);
        }

        TreeNode node = dt.createDT(data, attribute, null);

        // Query values for the first two permuted attributes (thin, weather).
        double[] dataExercise = {2, 3};
        List<Double> list = new ArrayList<Double>();
        for (int i = 0; i < dataExercise.length; i++) {
            list.add(dataExercise[i]);
        }

        node.traverse(list);

        System.out.println();
    }

}
package TreeStructure;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class DecisionTree {

    /**
     * Recursively builds a decision tree by splitting on the attributes in the
     * order they appear in {@code attributeList}.
     *
     * <p>Column {@code k} of every row in {@code data} corresponds to
     * {@code attributeList.get(k)}. When only one attribute remains, it is
     * treated as the target. One leaf is then attached per remaining row, with
     * duplicates kept, so the caller can count labels for a majority vote.
     *
     * @param data training rows; this method does not modify them, because
     *             {@code InfoGain} works on copies
     * @param attributeList column names aligned with the row columns; must be
     *             non-empty
     * @param node node to attach children to, or {@code null} to create a new
     *             root named "start"
     * @return {@code node}, or the newly created root, with its subtree attached
     * @throws IllegalArgumentException if {@code attributeList} is empty
     */
    public TreeNode createDT(List<ArrayList<String>> data, List<String> attributeList, TreeNode node) {
        printState(data, attributeList);

        if (node == null) {
            node = new TreeNode();
            node.setAttributeValue("start");
            node.setNodeName("start");
        }

        if (attributeList.isEmpty()) {
            // Previously this surfaced as an IndexOutOfBoundsException at attributeList.get(0).
            throw new IllegalArgumentException("attributeList must contain at least the target column");
        }

        if (attributeList.size() == 1) {
            addTargetLeaves(data, node);
            return node;
        }

        // NOTE(review): despite the message, no gain ratio is computed here.
        // The split attribute is simply the first remaining column.
        String splitAttribute = attributeList.get(0);
        System.out.println("选择出的最大增益率属性为: " + splitAttribute);
        InfoGain gain = new InfoGain(data, attributeList);

        Map<String, Long> attrvalueMap = gain.getAttributeValue(0);
        for (String value : attrvalueMap.keySet()) {
            List<ArrayList<String>> resultData = gain.getData4Value(value, 0);

            TreeNode branch = new TreeNode();
            branch.setAttributeValue(value);
            branch.setNodeName(splitAttribute);
            node.getChildTreeNode().add(branch);

            System.out.println("当前为" + splitAttribute + "的" + value + "分支。");

            // Drop the column just split on. The rows are copies owned by resultData.
            for (ArrayList<String> row : resultData) {
                row.remove(0);
            }
            ArrayList<String> resultAttr = new ArrayList<String>(attributeList);
            resultAttr.remove(0);
            createDT(resultData, resultAttr, branch);
        }

        return node;
    }

    /** Attaches one "target" leaf per row, holding that row's remaining (target) value. */
    private static void addTargetLeaves(List<ArrayList<String>> data, TreeNode node) {
        for (ArrayList<String> row : data) {
            TreeNode leafNode = new TreeNode();
            leafNode.setAttributeValue(row.get(0));
            leafNode.setNodeName("target");
            node.getChildTreeNode().add(leafNode);
        }
    }

    /** Dumps the current rows and attribute names to stdout for debugging. */
    private static void printState(List<ArrayList<String>> data, List<String> attributeList) {
        System.out.println("当前的DATA为");
        for (ArrayList<String> row : data) {
            for (String cell : row) {
                System.out.print(cell + " ");
            }
            System.out.println();
        }
        System.out.println("---------------------------------");
        System.out.println("当前的ATTR为");
        for (String attr : attributeList) {
            System.out.print(attr + " ");
        }
        System.out.println();
        System.out.println("---------------------------------");
    }
}
            
            
            
            
            
            
            
            
            
    
    

    
package TreeStructure;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Helper over a private deep copy of a training table: counts attribute values,
 * filters rows by value, and extracts target labels.
 */
public class InfoGain {
    private final List<ArrayList<String>> data;
    private final List<String> attribute;

    /**
     * Creates a helper over deep copies of {@code data} and {@code attribute}.
     * Later changes to the caller's lists do not affect this instance.
     *
     * @param data table rows
     * @param attribute column names
     */
    public InfoGain(List<ArrayList<String>> data, List<String> attribute) {
        this.data = new ArrayList<ArrayList<String>>(data.size());
        for (ArrayList<String> row : data) {
            this.data.add(new ArrayList<String>(row));
        }
        this.attribute = new ArrayList<String>(attribute);
    }

    /**
     * Counts how many rows have each distinct value in the given column.
     *
     * @param attributeIndex column to inspect
     * @return map from column value to number of rows with that value
     */
    public Map<String, Long> getAttributeValue(int attributeIndex) {
        Map<String, Long> attributeValueMap = new HashMap<String, Long>();
        for (ArrayList<String> row : data) {
            String key = row.get(attributeIndex);
            Long seen = attributeValueMap.get(key);
            attributeValueMap.put(key, seen == null ? 1L : seen + 1L);
        }
        return attributeValueMap;
    }

    /**
     * Returns copies of the rows whose value in column {@code attrIndex} equals
     * {@code attrValue}, ignoring case.
     *
     * <p>The returned rows are fresh lists. Mutating them does not affect this
     * instance.
     *
     * @param attrValue value to match
     * @param attrIndex column to compare
     * @return matching rows, possibly empty
     */
    public List<ArrayList<String>> getData4Value(String attrValue, int attrIndex) {
        List<ArrayList<String>> resultData = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> row : data) {
            if (row.get(attrIndex).equalsIgnoreCase(attrValue)) {
                // Copy constructor instead of an unchecked cast of clone().
                resultData.add(new ArrayList<String>(row));
            }
        }
        return resultData;
    }

    /**
     * Extracts the last column (the target label) of every non-empty row.
     *
     * <p>Empty rows are skipped. Previously the first empty row stopped the scan
     * and silently dropped the labels of all following rows.
     *
     * @param data table rows
     * @return target labels in row order
     */
    public static List<String> getTarget(List<ArrayList<String>> data) {
        List<String> list = new ArrayList<String>();
        for (ArrayList<String> row : data) {
            if (row.isEmpty()) {
                continue;
            }
            list.add(row.get(row.size() - 1));
        }
        return list;
    }

    /**
     * Checks whether all labels are identical, i.e. the node is 100% pure.
     *
     * @param list target labels
     * @return the common label if every element is equal, otherwise {@code null};
     *         also {@code null} for a {@code null} or empty list
     */
    public static String IsPure(List<String> list) {
        if (list == null || list.isEmpty()) {
            return null;
        }
        String first = list.get(0);
        for (String label : list) {
            boolean same = (first == null) ? (label == null) : first.equals(label);
            if (!same) {
                return null;
            }
        }
        return first;
    }

}
package TreeStructure;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


 class TreeNode{

private String attributeValue;
        private List<TreeNode> childTreeNode;
        private List<String> pathName;
        private String targetFunValue;
        private String nodeName;
        
        public TreeNode(String nodeName){
            
            this.nodeName = nodeName;
            this.childTreeNode = new ArrayList<TreeNode>();
            this.pathName = new ArrayList<String>();
        }
        
        public TreeNode(){
            this.childTreeNode = new ArrayList<TreeNode>();
            this.pathName = new ArrayList<String>();
        }

        public String getAttributeValue() {
            return attributeValue;
        }

        public void setAttributeValue(String attributeValue) {
            this.attributeValue = attributeValue;
        }

        public List<TreeNode> getChildTreeNode() {
            return childTreeNode;
        }

        public void setChildTreeNode(List<TreeNode> childTreeNode) {
            this.childTreeNode = childTreeNode;
        }

        public String getTargetFunValue() {
            return targetFunValue;
        }

        public void setTargetFunValue(String targetFunValue) {
            this.targetFunValue = targetFunValue;
        }

        public String getNodeName() {
            return nodeName;
        }

        public void setNodeName(String nodeName) {
            this.nodeName = nodeName;
        }

        public List<String> getPathName() {
            return pathName;
        }

        public void setPathName(List<String> pathName) {
            this.pathName = pathName;
        }
        
        public void traverse() {  
            System.out.println(this.getNodeName()+":   "+this.getAttributeValue());
            int childNumber = this.childTreeNode.size(); 
            System.out.println(childNumber);
            for (int i = 0; i < childNumber; i++) {  
                TreeNode child = this.childTreeNode.get(i);  
                child.traverse();  
            }  
        }  
        
        
        public List getTarget(TreeNode node){
            List a = new ArrayList();;
            int childNum = node.getChildTreeNode().size();
            if(node.childTreeNode.get(0).childTreeNode.size()==0){//表示node孩子的孩子为空,即node下一层为目标层
                for(int i = 0;i<childNum;i++){
                    a.add(node.getChildTreeNode().get(i).getAttributeValue());
                    
                }
                
            }else{
                for(int i = 0;i<childNum;i++){
                    a.addAll(getTarget(node.getChildTreeNode().get(i)));
                }
            }
            return a;
        }
        public void traverse(List list) {
            if(list.size()==0){
                List target = getTarget(this);
//                int childlistNumber = this.childTreeNode.size(); 
//                List a = new ArrayList();
//                for(int i = 0;i<childlistNumber;i++){
//                TreeNode child = this.childTreeNode.get(i);
//                a.add(child.getAttributeValue());
//                }
                List b = new ArrayList();
//                Map result = new HashMap();
                for(int i = 0;i<target.size();i++){
                    if(!b.contains(target.get(i))){
                    b.add(target.get(i));
                    }
                }
                int []count = new int [b.size()];
                for(int i = 0;i<b.size();i++){
                    
                    for(int j = 0;j<target.size();j++){
                        if(b.get(i).equals(target.get(j))){
                            count[i] = count[i]+1;
                        }
                    }
                    System.out.println(b.get(i)+"的数量是:   "+count[i]);
                }
                int maxIndex = 0;
                for(int i = 1;i<count.length;i++){
                    if(count[maxIndex]<count[i]){
                        maxIndex = i;
                    }
                }
                System.out.println("选择"+b.get(maxIndex)+"为最终决策");
                
                
                
                
            }else{
            List a = new ArrayList();
            double temp = (Double)list.get(0);
            int childlistNumber = this.childTreeNode.size(); 
            System.out.println(childlistNumber);
            for(int i = 0;i<childlistNumber;i++){
                TreeNode child = this.childTreeNode.get(i);  
                double tempchild = Double.valueOf(child.getAttributeValue());
                if(temp==tempchild){
                    System.out.println(child.getNodeName()+":   "+child.getAttributeValue());
                    list.remove(0);
                    child.traverse(list);
                }
            }
            }
        }
 }
        
    
 

 

以上是关于决策树的主要内容。如果未能解决你的问题，请参考以下文章：

如何在 Python 中绘制回归树

sklearn决策树算法DecisionTreeClassifier(API)的使用以及决策树代码实例 - 莺尾花分类

机器学习:通俗易懂决策树与随机森林及代码实践

决策树的几种类型差异及Spark 2.0-MLlib、Scikit代码分析

机器学习_决策树Python代码详解

Chapter3 绘制决策树