决策树
Posted 这个签名很没水平
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了决策树相关的知识,希望对你有一定的参考价值。
package TreeStructure;

import java.util.ArrayList;
import java.util.List;

/**
 * Demo driver: builds a decision tree from a small hard-coded data set and
 * then looks up a decision for one query.
 *
 * <p>The raw columns are reordered through {@code index} before the tree is
 * built. The attribute labels are reordered with the same permutation so that
 * every tree level carries the name of the column it actually holds.
 */
public class testClass {
    public static void main(String[] args) {
        // Raw samples; column order matches Attribute below.
        double[][] exercise = {
            {1, 1, 0, 0}, {1, 3, 1, 1}, {3, 2, 0, 0}, {3, 2, 1, 10}, {3, 2, 1, 10},
            {3, 2, 1, 10}, {2, 2, 1, 1}, {3, 2, 1, 9}, {2, 3, 0, 1}, {2, 1, 0, 0},
            {3, 2, 0, 1}, {2, 1, 0, 1}, {1, 1, 0, 1}
        };
        String[] Attribute = {"weather", "thin", "cloth", "target"};
        // Column permutation: position j of the reordered data holds raw column index[j].
        int[] index = {1, 0, 2, 3};

        // Reorder the columns of every sample.
        double[][] exerciseData = new double[exercise.length][];
        for (int i = 0; i < exerciseData.length; i++) {
            exerciseData[i] = new double[exercise[i].length];
            for (int j = 0; j < exerciseData[i].length; j++) {
                exerciseData[i][j] = exercise[i][index[j]];
            }
        }

        // Echo the reordered data.
        for (int i = 0; i < exerciseData.length; i++) {
            for (int j = 0; j < exerciseData[i].length; j++) {
                System.out.print(" " + exerciseData[i][j]);
            }
            System.out.println();
        }

        // The tree builder works on string rows.
        List<ArrayList<String>> data = new ArrayList<ArrayList<String>>();
        for (int i = 0; i < exerciseData.length; i++) {
            ArrayList<String> row = new ArrayList<String>();
            for (int j = 0; j < exerciseData[i].length; j++) {
                row.add(exerciseData[i][j] + "");
            }
            data.add(row);
        }

        // Apply the same permutation to the labels; previously the labels kept
        // their raw order and no longer matched the reordered columns.
        List<String> attribute = new ArrayList<String>();
        for (int k = 0; k < Attribute.length; k++) {
            attribute.add(Attribute[index[k]]);
        }

        DecisionTree dt = new DecisionTree();
        // Passing null lets createDT create the root node itself.
        TreeNode node = dt.createDT(data, attribute, null);

        // Query values, one per tree level in reordered column order.
        double[] dataExercise = {2, 3};
        List<Double> list = new ArrayList<Double>();
        for (int i = 0; i < dataExercise.length; i++) {
            list.add(dataExercise[i]);
        }
        node.traverse(list);
        System.out.println();
    }
}
package TreeStructure;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Builds a tree by recursively splitting the data on its first column.
 */
public class DecisionTree {

    /**
     * Recursively attaches children to {@code node}.
     *
     * <p>When exactly one attribute is left, one leaf named "target" is added
     * per row, holding that row's first value. Otherwise the rows are grouped
     * by the distinct values of column 0, one child is added per value, and
     * the method recurses on each group with column 0 and its label removed.
     * Although the log message mentions a maximum gain ratio, the split
     * attribute is always the first one in {@code attributeList}.
     *
     * @param data          rows of string values; not modified
     * @param attributeList labels for the columns of {@code data}; not modified
     * @param node          node to attach children to, or {@code null} to
     *                      create a fresh root named "start"
     * @return {@code node}, or the newly created root
     */
    public TreeNode createDT(List<ArrayList<String>> data, List<String> attributeList, TreeNode node) {
        // Trace the current sub-problem.
        System.out.println("当前的DATA为");
        for (ArrayList<String> row : data) {
            for (String cell : row) {
                System.out.print(cell + " ");
            }
            System.out.println();
        }
        System.out.println("---------------------------------");
        System.out.println("当前的ATTR为");
        for (String label : attributeList) {
            System.out.print(label + " ");
        }
        System.out.println();
        System.out.println("---------------------------------");

        TreeNode current = node;
        if (current == null) {
            current = new TreeNode();
            current.setAttributeValue("start");
            current.setNodeName("start");
        }

        // Only the target column is left: every row becomes a leaf.
        if (attributeList.size() == 1) {
            for (ArrayList<String> row : data) {
                TreeNode leaf = new TreeNode();
                leaf.setAttributeValue(row.get(0));
                leaf.setNodeName("target");
                current.getChildTreeNode().add(leaf);
            }
            return current;
        }

        String splitAttribute = attributeList.get(0);
        System.out.println("选择出的最大增益率属性为: " + splitAttribute);

        InfoGain gain = new InfoGain(data, attributeList);
        for (String value : gain.getAttributeValue(0).keySet()) {
            List<ArrayList<String>> subset = gain.getData4Value(value, 0);

            TreeNode branch = new TreeNode();
            branch.setAttributeValue(value);
            branch.setNodeName(splitAttribute);
            current.getChildTreeNode().add(branch);
            System.out.println("当前为" + splitAttribute + "的" + value + "分支。");

            // Drop the column that was just split on, and its label.
            for (ArrayList<String> row : subset) {
                row.remove(0);
            }
            ArrayList<String> remainingAttributes =
                    new ArrayList<String>(attributeList.subList(1, attributeList.size()));

            createDT(subset, remainingAttributes, branch);
        }
        return current;
    }
}
package TreeStructure;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Helper that holds a private copy of a data set and answers simple
 * per-column queries on it (value counts, row selection).
 */
public class InfoGain {
    private List<ArrayList<String>> data;
    private List<String> attribute;

    /**
     * Copies the given rows and attribute labels so that later changes made by
     * the caller do not affect this instance, and vice versa.
     *
     * @param data      rows of string values
     * @param attribute labels for the columns of {@code data}
     */
    public InfoGain(List<ArrayList<String>> data, List<String> attribute) {
        this.data = new ArrayList<ArrayList<String>>();
        for (int i = 0; i < data.size(); i++) {
            this.data.add(new ArrayList<String>(data.get(i)));
        }
        this.attribute = new ArrayList<String>(attribute);
    }

    /**
     * Counts how often each distinct value occurs in one column.
     *
     * @param attributeIndex zero-based column index
     * @return map from value to number of rows holding that value
     */
    public Map<String, Long> getAttributeValue(int attributeIndex) {
        Map<String, Long> attributeValueMap = new HashMap<String, Long>();
        for (ArrayList<String> note : data) {
            String key = note.get(attributeIndex);
            Long value = attributeValueMap.get(key);
            attributeValueMap.put(key, value != null ? value + 1 : 1L);
        }
        return attributeValueMap;
    }

    /**
     * Selects the rows whose value in the given column equals {@code attrValue}.
     *
     * <p>The comparison is exact (case-sensitive, null-safe) so that it agrees
     * with the keys produced by {@link #getAttributeValue(int)}; a
     * case-insensitive match would put rows for "a" and "A" into both
     * branches. Returned rows are copies and may be modified freely.
     *
     * @param attrValue value to match
     * @param attrIndex zero-based column index
     * @return copies of the matching rows, in their original order
     */
    public List<ArrayList<String>> getData4Value(String attrValue, int attrIndex) {
        List<ArrayList<String>> resultData = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> row : data) {
            String cell = row.get(attrIndex);
            boolean matches = (cell == null) ? attrValue == null : cell.equals(attrValue);
            if (matches) {
                // Copy constructor instead of clone(): no unchecked cast needed.
                resultData.add(new ArrayList<String>(row));
            }
        }
        return resultData;
    }

    /**
     * Collects the last value of every row.
     *
     * <p>Collection stops at the first empty row; rows after it are ignored.
     *
     * @param data rows of string values
     * @return last values of the rows preceding the first empty row
     */
    public static List<String> getTarget(List<ArrayList<String>> data) {
        List<String> list = new ArrayList<String>();
        for (ArrayList<String> temp : data) {
            int index = temp.size() - 1;
            if (index == -1) {
                break;
            }
            list.add(temp.get(index));
        }
        return list;
    }

    /**
     * Intended to judge whether the current purity is 100%.
     *
     * <p>NOTE(review): no purity check is performed; the method simply returns
     * the first element and throws {@link IndexOutOfBoundsException} for an
     * empty list. Confirm the intended contract before relying on it.
     *
     * @param list target values
     * @return the first element of {@code list}
     */
    public static String IsPure(List<String> list) {
        return list.get(0);
    }
}
package TreeStructure;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Node of the decision tree: a name (the attribute it belongs to), the
 * attribute value it stands for, and an ordered list of children.
 */
class TreeNode {
    private String attributeValue;
    private List<TreeNode> childTreeNode;
    private List<String> pathName;
    private String targetFunValue;
    private String nodeName;

    /** Creates a node with the given name and no children. */
    public TreeNode(String nodeName) {
        this.nodeName = nodeName;
        this.childTreeNode = new ArrayList<TreeNode>();
        this.pathName = new ArrayList<String>();
    }

    /** Creates an unnamed node with no children. */
    public TreeNode() {
        this.childTreeNode = new ArrayList<TreeNode>();
        this.pathName = new ArrayList<String>();
    }

    public String getAttributeValue() {
        return attributeValue;
    }

    public void setAttributeValue(String attributeValue) {
        this.attributeValue = attributeValue;
    }

    public List<TreeNode> getChildTreeNode() {
        return childTreeNode;
    }

    public void setChildTreeNode(List<TreeNode> childTreeNode) {
        this.childTreeNode = childTreeNode;
    }

    public String getTargetFunValue() {
        return targetFunValue;
    }

    public void setTargetFunValue(String targetFunValue) {
        this.targetFunValue = targetFunValue;
    }

    public String getNodeName() {
        return nodeName;
    }

    public void setNodeName(String nodeName) {
        this.nodeName = nodeName;
    }

    public List<String> getPathName() {
        return pathName;
    }

    public void setPathName(List<String> pathName) {
        this.pathName = pathName;
    }

    /**
     * Prints this node ("name: value"), then its child count, then every
     * child recursively in list order.
     */
    public void traverse() {
        System.out.println(this.getNodeName() + ": " + this.getAttributeValue());
        int childNumber = this.childTreeNode.size();
        System.out.println(childNumber);
        for (int i = 0; i < childNumber; i++) {
            this.childTreeNode.get(i).traverse();
        }
    }

    /**
     * Collects the attribute values found on the deepest layer below
     * {@code node}.
     *
     * <p>Whether the children are that layer is decided by looking at the
     * first child only. A node without children yields an empty list instead
     * of failing with an {@link IndexOutOfBoundsException}.
     *
     * @param node subtree root to collect from
     * @return collected values in traversal order; empty if {@code node} has
     *         no children
     */
    public List<String> getTarget(TreeNode node) {
        List<String> targets = new ArrayList<String>();
        List<TreeNode> children = node.getChildTreeNode();
        if (children.isEmpty()) {
            return targets;
        }
        if (children.get(0).getChildTreeNode().isEmpty()) {
            // The children have no children: the layer below node is the target layer.
            for (TreeNode child : children) {
                targets.add(child.getAttributeValue());
            }
        } else {
            for (TreeNode child : children) {
                targets.addAll(getTarget(child));
            }
        }
        return targets;
    }

    /**
     * Follows a query down the tree and prints the majority target value of
     * the subtree that is reached.
     *
     * <p>The first query element is compared numerically with the attribute
     * value of every child; each matching child is printed and traversed with
     * the remaining elements. Once the query is exhausted, the target values
     * below this node are counted and the most frequent one (the first seen on
     * ties) is printed.
     *
     * <p>The query list is not modified. Earlier versions removed elements
     * from it while still looping over siblings, so a second matching sibling
     * saw a shortened or empty list and failed.
     *
     * @param list query values, one numeric element per tree level
     */
    public void traverse(List<?> list) {
        if (list.isEmpty()) {
            printMajorityDecision(getTarget(this));
            return;
        }

        // Accept any numeric element type, not only Double.
        double wanted = ((Number) list.get(0)).doubleValue();
        // Copy of the tail; safe to share between siblings because traverse never mutates it.
        List<Object> remaining = new ArrayList<Object>(list.subList(1, list.size()));

        int childCount = this.childTreeNode.size();
        System.out.println(childCount);
        for (int i = 0; i < childCount; i++) {
            TreeNode child = this.childTreeNode.get(i);
            double childValue = Double.parseDouble(child.getAttributeValue());
            if (wanted == childValue) {
                System.out.println(child.getNodeName() + ": " + child.getAttributeValue());
                child.traverse(remaining);
            }
        }
    }

    /** Prints the count of every distinct target value and the most frequent one. */
    private void printMajorityDecision(List<String> targets) {
        if (targets.isEmpty()) {
            // Nothing below this node to vote on.
            return;
        }

        // Distinct values in first-seen order.
        List<String> distinct = new ArrayList<String>();
        for (String target : targets) {
            if (!distinct.contains(target)) {
                distinct.add(target);
            }
        }

        int[] count = new int[distinct.size()];
        for (int i = 0; i < distinct.size(); i++) {
            for (String target : targets) {
                if (distinct.get(i).equals(target)) {
                    count[i]++;
                }
            }
            System.out.println(distinct.get(i) + "的数量是: " + count[i]);
        }

        // Strict comparison keeps the first value on ties.
        int maxIndex = 0;
        for (int i = 1; i < count.length; i++) {
            if (count[maxIndex] < count[i]) {
                maxIndex = i;
            }
        }
        System.out.println("选择" + distinct.get(maxIndex) + "为最终决策");
    }
}
以上是关于决策树的主要内容,如果未能解决你的问题,请参考以下文章:
sklearn决策树算法DecisionTreeClassifier(API)的使用以及决策树代码实例 - 鸢尾花分类