The Beauty of Data Analysis: Implementing Decision Trees in R



Implementing decision trees in R
1. Prepare the data
> install.packages("tree")
> library(tree)
> library(ISLR)
> attach(Carseats)
> High=ifelse(Sales<=8,"No","Yes")   # derive a binary High label from Sales for classification
> Carseats=data.frame(Carseats,High) # append the High column to the data set
> fix(Carseats)
2. Fit the decision tree

> tree.carseats=tree(High~.-Sales,Carseats)
> summary(tree.carseats)

# output: the training error rate is 9%
Classification tree:
tree(formula = High ~ . - Sales, data = Carseats)
Variables actually used in tree construction:
[1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
[6] "Advertising" "Age" "US"
Number of terminal nodes: 27
Residual mean deviance: 0.4575 = 170.7 / 373    # 373 = 400 observations - 27 terminal nodes
Misclassification error rate: 0.09 = 36 / 400
3. Plot the decision tree

> plot(tree.carseats)
> text(tree.carseats,pretty=0)
4. Estimate the test error


# prepare the training and test data
# We begin by using the sample() function to split the observations into two
# halves, selecting a random subset of 200 of the original 400 observations.
> set.seed(1)
> train=sample(1:nrow(Carseats),200)
> Carseats.test=Carseats[-train,]
> High.test=High[-train]
# fit the tree model on the training data only
> tree.carseats=tree(High~.-Sales,Carseats,subset=train)
# estimate the test error with the tree model and predict();
# predict() is a generic function for predictions from fitted model objects
> tree.pred=predict(tree.carseats,Carseats.test,type="class")
> table(tree.pred,High.test)
         High.test
tree.pred No Yes
      No  86  27
      Yes 30  57
> (86+57)/200
[1] 0.715

5. Prune the tree

# Next, we consider whether pruning the tree might lead to improved results.
# The cv.tree() function performs cross-validation to determine the optimal
# level of tree complexity; cost-complexity pruning is used to select a
# sequence of trees for consideration.
#
# For regression trees, only the default, deviance, is accepted. For
# classification trees, the default is deviance and the alternative is
# misclass (the number of misclassifications, or total loss).
# We use the argument FUN=prune.misclass to indicate that the classification
# error rate, rather than the default deviance, should guide the
# cross-validation and pruning process.
#
# For a regression tree one would plot, e.g.:
# > plot(cv.boston$size,cv.boston$dev,type="b")
> set.seed(3)
> cv.carseats=cv.tree(tree.carseats,FUN=prune.misclass,K=10)
# cv.tree() reports the number of terminal nodes of each tree considered (size),
# the corresponding error rate (dev), and the value of the cost-complexity
# parameter used (k, which corresponds to α)
> names(cv.carseats)
[1] "size"   "dev"    "k"      "method"
> cv.carseats
$size    # the number of terminal nodes of each tree considered
[1] 19 17 14 13  9  7  3  2  1
$dev     # the corresponding cross-validation error
[1] 55 55 53 52 50 56 69 65 80
$k       # the value of the cost-complexity parameter used
[1]      -Inf 0.0000000 0.6666667 1.0000000 1.7500000 2.0000000 4.2500000
[8] 5.0000000 23.0000000
$method  # misclass for a classification tree
[1] "misclass"
attr(,"class")
[1] "prune"         "tree.sequence"


# plot the error rate against tree size to see which size is best
> plot(cv.carseats$size,cv.carseats$dev,type="b")

# Note that, despite the name, dev corresponds to the cross-validation error
# rate here. The tree with 9 terminal nodes has the lowest cross-validation
# error rate, with 50 cross-validation errors. We plot the error rate as a
# function of both size and k.
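# ...and, as described above, against k as well:
> plot(cv.carseats$k,cv.carseats$dev,type="b")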
> prune.carseats=prune.misclass(tree.carseats,best=9)
> plot(prune.carseats)
> text(prune.carseats,pretty=0)

# compute the test error again to see how the pruned tree performs on the test set
> tree.pred=predict(prune.carseats,Carseats.test,type="class")
> table(tree.pred,High.test)
         High.test
tree.pred No Yes
      No  94  24
      Yes 22  60
> (94+60)/200
[1] 0.77

50 Practical R Applications Explained in Depth (Part 31): Implementing Decision Trees in R (R code included)

Decision tree regression

First, train the decision tree model with the rpart() function from the rpart package. You need to supply a formula and a data set, set the model to regression by specifying method = "anova", and finally pass the control parameters: mainly pre-pruning parameters that govern tree growth, such as the maximum tree depth, the minimum leaf-node sample size, and the complexity parameter (see the function's help documentation for details, and the sketch below).
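For example, a fuller set of pre-pruning controls might look like the following sketch (the parameter values here are illustrative assumptions, not taken from the original):

# illustrative pre-pruning controls (values chosen for demonstration only)
library(rpart)
ctrl <- rpart.control(
  maxdepth  = 10,    # maximum depth of any node of the tree
  minbucket = 20,    # minimum number of observations in a leaf node
  cp        = 0.005  # complexity parameter: minimum fit improvement required for a split
)
# the control object can then be passed to the fit via rpart(..., control = ctrl)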

Next, print the model fitted above, i.e. the initial tree, along with the complexity table and plot. Based on these, an optimal cp value can be chosen for the subsequent post-pruning step. The post-pruned tree is the final decision tree model, used for the output and predictions that follow.

Finally, output the graphics associated with the decision tree, including a variable-importance bar chart and the tree diagram.

# Train the model
# See the rpart reference documentation for details
library(rpart) # load rpart (assumed here; the library call is not shown in the original)
set.seed(42) # fix the cross-validation folds for reproducibility
fit_dt_reg <- rpart(
  form_reg, # model formula (defined earlier in the original code)
  data = traindata,
  method = "anova", # regression
  # regression models take no parms argument
  control = rpart.control(cp = 0.005)
)
# the initial regression tree
fit_dt_reg
# complexity-related table and plot
printcp(fit_dt_reg)
plotcp(fit_dt_reg)
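# One way to pick the cp used below: the 1-SE rule applied to the cptable
# (a minimal sketch; cp1SE is used in the original but its definition is not shown)
cptab <- fit_dt_reg$cptable
best <- which.min(cptab[, "xerror"])
thresh <- cptab[best, "xerror"] + cptab[best, "xstd"]
cp1SE <- cptab[which(cptab[, "xerror"] <= thresh)[1], "CP"]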

# Post-pruning with the chosen cp
fit_dt_reg_pruned <- prune(fit_dt_reg, cp = cp1SE)
print(fit_dt_reg_pruned)
summary(fit_dt_reg_pruned)

# variable-importance values
fit_dt_reg_pruned$variable.importance
# variable-importance bar chart
varimpdata <-
  data.frame(importance = fit_dt_reg_pruned$variable.importance)
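# The excerpt is truncated below at "g"; a minimal ggplot2 sketch of the bar
# chart (ggplot2 and the reordering step are assumptions, not from the original):
library(ggplot2)
varimpdata$variable <- rownames(varimpdata)
g <- ggplot(varimpdata, aes(x = reorder(variable, importance), y = importance)) +
  geom_col() +
  coord_flip() +
  labs(x = "Variable", y = "Importance")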
g
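
The tree diagram mentioned above does not appear in the excerpt; a minimal sketch, assuming the rpart.plot package, would be:

# draw the pruned tree (rpart.plot is an assumption; any tree-plotting helper works)
library(rpart.plot)
rpart.plot(fit_dt_reg_pruned)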
