sklearn RandomForestRegressor 显示的树值中的差异
Posted
技术标签:
【中文标题】sklearn RandomForestRegressor 显示的树值中的差异【英文标题】:sklearn RandomForestRegressor discrepancy in the displayed tree values 【发布时间】:2021-01-03 08:10:06 【问题描述】:在使用 RandomForestRegressor 时,我发现了一些奇怪的东西。为了说明问题,这里举一个小例子。我在测试数据集上应用了 RandomForestRegressor 并绘制了森林中第一棵树的图形。这给了我以下输出:
Root_node:
mse=8.64
samples=2
value=20.4
Left_leaf:
mse=0
samples=1
value=24
Right_leaf:
mse=0
samples=1
value=18
首先,我希望根节点的值为(24+18)/2=21
。但不知何故它是20.4。
但是,即使这个值是正确的,我如何获得 8.64 的 mse?
在我看来应该是:1/2[(24-20.4)^2+(18-20.4)^2]=9.36
(假设根值 20.4 是正确的)
我的解决方案是:1/2[(24-21)^2+(18-21)^2]=9
。如果我只使用 DecisionTreeRegressor,这也是我得到的。
RandomForestRegressor 的实现有问题还是我完全错了?
这是我的可重现代码:
import pandas as pd
from sklearn import tree
from sklearn.ensemble import RandomForestRegressor
import graphviz
# create example dataset
data = 'AGE': [91, 42, 29, 94, 85], 'TAX': [384, 223, 280, 666, 384], 'Y': [19, 21, 24, 13, 18]
df = pd.DataFrame(data=data)
x = df[['AGE','TAX']]
y = df[['Y']]
rf_reg = RandomForestRegressor(max_depth=2, random_state=1)
rf_reg.fit(x,y)
# plot a single tree of forest
dot_data = tree.export_graphviz(rf_reg.estimators_[0], out_file=None, feature_names=x.columns)
graph = graphviz.Source(dot_data)
graph
和输出图:
【问题讨论】:
【参考方案1】:tl;dr
这是由于引导采样。
详细说明:
使用默认设置bootstrap=True
,RF 将在构建单个树时使用引导采样;引用交叉验证线程Number of Samples per-Tree in a Random Forest:
如果
bootstrap=True
,那么对于每棵树,从训练集中随机抽取 N 个带放回的样本,并且树是在这个新版本的训练数据上构建的。这在训练过程中引入了随机性,因为每棵树都将在略有不同的训练集上进行训练。可以预期,从大小为 N 的数据集中抽取 N 个带有替换的样本将从原始集合中选择约 2/3 的唯一样本。
“With replacement”表示有些样本可能会被多次选择,而另一些则会被遗漏,剩下的被选择样本总数等于原始数据集的样本数(这里是 5 个)。
您显示的树中实际发生的情况是,尽管 Graphviz 显示 samples=2
,但这应该被理解为 唯一 样本的数量;根节点中总共有 5 (bootstrap) 样本:样本的 2 个副本为 y=24
,样本的 3 个副本为 y=18
(回想一下,根据 bootstrap 采样的定义过程中,这里的根节点必须包含 5 个样本,不多也不少)。
现在显示的值加起来:
# value:
(2*24 + 3*18)/5
# 20.4
# mse:
(2*(24-20.4)**2 + 3*(18-20.4)**2)/5
# 8.64
显然似乎有一些设计选择,无论是在 Graphviz 可视化中还是在底层 DecisionTreeRegressor
中,以便只存储/显示 unique 样本的数量,这可能(或可能不是)是打开 Github 问题的原因,但这就是目前的情况(老实说,我不确定自己是否希望在此处显示实际的 total 样本数,包括由于自举抽样而产生的重复)。
【讨论】:
以上是关于sklearn RandomForestRegressor 显示的树值中的差异的主要内容,如果未能解决你的问题,请参考以下文章
无法从 sklearn.externals.joblib 导入 Sklearn