Spark mllib 决策树的统计信息
Posted
技术标签:
【中文标题】Spark mllib 决策树的统计信息【英文标题】:Statistics for Spark mllib DecisionTree 【发布时间】:2016-10-07 03:25:02 【问题描述】:在学习了一个 mllib 决策树模型 (http://spark.apache.org/docs/latest/mllib-decision-tree.html) 后,我如何计算节点统计信息,例如支持度(有多少样本匹配此子树)以及每个标签有多少样本匹配此子树?
如果它更容易,我也很乐意使用 Spark 以外的任何其他工具来获取调试字符串并计算这些统计信息。调试字符串示例:
DecisionTreeModel classifier of depth 20 with 20031 nodes
If (feature 0 <= -35.0)
If (feature 24 <= 176.0)
If (feature 0 <= -200.0)
If (feature 29 <= 109.0)
If (feature 6 <= -156.0)
If (feature 9 <= 0.0)
If (feature 20 <= -116.0)
If (feature 16 <= 203.0)
If (feature 11 <= 163.0)
If (feature 5 <= 384.0)
If (feature 15 <= 325.0)
If (feature 13 <= -248.0)
If (feature 20 <= -146.0)
Predict: 0.0
Else (feature 20 > -146.0)
If (feature 19 <= -58.0)
Predict: 6.0
Else (feature 19 > -58.0)
Predict: 0.0
Else (feature 13 > -248.0)
If (feature 9 <= -26.0)
Predict: 0.0
Else (feature 9 > -26.0)
If (feature 10 <= 218.0)
...
我使用 mllib 是因为我需要进行核外学习,因为数据不适合内存。如果您有比 mllib 更好的替代品,我很乐意尝试一下。
【问题讨论】:
【参考方案1】:我使用sklearn 作为算法来创建我的模型,并与 Spark Context 集成以产生这样的输出:
if ( device_type_id <= 1 )
39 Clicks - 0.61%
2135 Conversions - 33.32%
else ( device_type_id > 1 )
if ( country_id <= 216 )
1097 Clicks - 17.12%
else ( country_id > 216 )
if ( browser_id <= 2 )
296 Clicks - 4.62%
else ( browser_id > 2 )
if ( browser_id <= 4 )
if ( browser_id <= 3 )
if ( operating_system_id <= 2 )
262 Clicks - 4.09%
这是我用来显示这样一棵树的代码:
def get_code(count_df, tree, feature_names, target_names, spacer_base=" "):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
temp_list = []
res_count = count_df
def recurse(res_count, temp_list, left, right, threshold, features, node, depth):
spacer = spacer_base * depth
if (threshold[node] != -2):
temp_list.append("if ( " + features[node] + " <= " + \
str(int(round(threshold[node] - 1))) + " )")
if left[node] != -1:
recurse (res_count, temp_list, left, right, threshold, features, left[node], depth+1)
temp_list.append("else ( " + features[node] + " > " + \
str(int(round(threshold[node] - 1))) + " )")
if right[node] != -1:
recurse (res_count, temp_list, left, right, threshold, features, right[node], depth+1)
else:
target = value[node]
for i, v in zip(np.nonzero(target)[1], target[np.nonzero(target)]):
target_name = target_names[i]
target_count = int(v)
temp_list.append(str(target_count) +" "+ str(target_name) + " - " + str(round((target_count / res_count), 4) * 100)+ "%")
recurse(res_count, temp_list, left, right, threshold, features, 0, 0)
return temp_list
否则,请参考我的帖子here中提供的答案,但是它是用Scala
写的,改变了Spark
生成决策树的方式。
【讨论】:
我不能使用 sklearn 决策树,因为它们不支持在线/非核心训练。但是你得到的输出看起来可能是我想要的(你有两个标签,点击和转化是对的吗?)。你能提供一些代码来获得这个输出吗?我也可以从 spark mllib 模型中获取它吗?以上是关于Spark mllib 决策树的统计信息的主要内容,如果未能解决你的问题,请参考以下文章
spark.mllib源码阅读-分类算法4-DecisionTree
Apache Spark:Mllib之决策树的操作(java)