你能从 scikit-learn 中的 DecisionTreeRegressor 中获取选定的叶子吗
Posted
技术标签:
【中文标题】你能从 scikit-learn 中的 DecisionTreeRegressor 中获取选定的叶子吗【英文标题】:Can you get the selected leaf from a DecisionTreeRegressor in scikit-learn 【发布时间】:2015-02-23 14:50:46 【问题描述】:只是阅读此great paper 并尝试实现此:
...我们对待每一个人 树作为一个分类特征,将 一个实例最终落入的叶子的索引。我们使用 1- 这类特征的of-K编码。例如,考虑 图 1 中具有 2 个子树的提升树模型,其中 第一个子树有 3 个叶子,第二个有 2 个叶子。如果 实例在第一个子树的叶子 2 和叶子 1 中结束 第二个子树,线性分类器的整体输入将 是二进制向量 [0, 1, 0, 1, 0],其中前 3 个条目 对应于第一个子树的叶子和最后2个 第二个子树的那些...
有谁知道我如何预测一堆行并为每一行获取集合中每棵树的选定叶子?对于这个用例,我并不关心节点代表什么,只关心它的索引。查看了源代码,我无法很快看到任何明显的东西。我可以看到我需要迭代树并执行以下操作:
for sample in X_test:
for tree in gbc.estimators_:
leaf = tree.leaf_index(sample) # This is the function I need but don't think exists.
...
任何指针表示赞赏。
【问题讨论】:
【参考方案1】:以下功能超越了从决策树中识别选定的叶子,并实现了参考论文中的应用程序。它的用法与参考论文相同,我使用 GBC 进行特征工程。
def makeTreeBins(gbc, X):
'''
Takes in a GradientBoostingClassifier object (gbc) and a data frame (X).
Returns a numpy array of dim (rows(X), num_estimators), where each row represents the set of terminal nodes
that the record X[i] falls into across all estimators in the GBC.
Note, each tree produces 2^max_depth terminal nodes. I append a prefix to the terminal node id in each incremental
estimator so that I can use these as feature ids in other classifiers.
'''
for i, dt_i in enumerate(gbc.estimators_):
prefix = (i + 2)*100 #Must be an integer
nds = prefix + dt_i[0].tree_.apply(np.array(X).astype(np.float32))
if i == 0:
nd_mat = nds.reshape(len(nds), 1)
else:
nd_mat = np.hstack((nd, nds.reshape(len(nds), 1)))
return nd_mat
【讨论】:
这是怎么回答问题的? 它展示了如何为梯度提升分类器中的每棵树获取数据帧中每条记录的叶节点。它专门解决了如何实现参考论文中的方法。 好的,谢谢。尽量避免像我有同样问题的词,直接尝试给出答案,否则有些人会混淆并标记答案 nd_mat = np.hstack((nd, nds.reshape(len(nds), 1))) 应该是 np.hstack((nd_map, nds.reshape(len(nds), 1) ))【参考方案2】:DecisionTreeRegressor 具有tree_
属性,可让您访问底层决策树。它有方法apply
,貌似找到了对应的叶子id:
dt.tree_.apply(X)
请注意,apply
期望其输入的类型为 float32
。
【讨论】:
是的,我知道如何遍历节点,但是如何确定终端叶而无需遍历每个节点并检查阈值,事实上这非常困难,因为 GBC 确实具有子样本,对吗?因此,我还必须考虑该树的选定功能。 @gatapia,我更新了我的答案,原来sklearn中有这样的功能。以上是关于你能从 scikit-learn 中的 DecisionTreeRegressor 中获取选定的叶子吗的主要内容,如果未能解决你的问题,请参考以下文章
你能从部署在 GitHub 上的 heroku 应用程序写入 JSON 文件吗
discord.py 你能从 discord 标签中获取用户对象吗?