Spark mllib 决策树的统计信息

Posted

技术标签:

【中文标题】Spark mllib 决策树的统计信息【英文标题】:Statistics for Spark mllib DecisionTree 【发布时间】:2016-10-07 03:25:02 【问题描述】:

在学习了一个 mllib 决策树模型 (http://spark.apache.org/docs/latest/mllib-decision-tree.html) 后,我如何计算节点统计信息,例如支持度(有多少样本匹配此子树)以及每个标签有多少样本匹配此子树?

如果它更容易,我也很乐意使用 Spark 以外的任何其他工具来获取调试字符串并计算这些统计信息。调试字符串示例:

DecisionTreeModel classifier of depth 20 with 20031 nodes
  If (feature 0 <= -35.0)
   If (feature 24 <= 176.0)
    If (feature 0 <= -200.0)
     If (feature 29 <= 109.0)
      If (feature 6 <= -156.0)
       If (feature 9 <= 0.0)
        If (feature 20 <= -116.0)
         If (feature 16 <= 203.0)
          If (feature 11 <= 163.0)
           If (feature 5 <= 384.0)
            If (feature 15 <= 325.0)
             If (feature 13 <= -248.0)
              If (feature 20 <= -146.0)
               Predict: 0.0
              Else (feature 20 > -146.0)
               If (feature 19 <= -58.0)
                Predict: 6.0
               Else (feature 19 > -58.0)
                Predict: 0.0
             Else (feature 13 > -248.0)
              If (feature 9 <= -26.0)
               Predict: 0.0
              Else (feature 9 > -26.0)
               If (feature 10 <= 218.0)
...

我使用 mllib 是因为我需要进行核外学习,因为数据不适合内存。如果您有比 mllib 更好的替代品,我很乐意尝试一下。

【问题讨论】:

【参考方案1】:

我使用sklearn 作为算法来创建我的模型,并与 Spark Context 集成以产生这样的输出:

if ( device_type_id <= 1 )
    39 Clicks - 0.61%
    2135 Conversions - 33.32% 
else ( device_type_id > 1 )
    if ( country_id <= 216 )
        1097 Clicks - 17.12%
    else ( country_id > 216 )
        if ( browser_id <= 2 )
            296 Clicks - 4.62%
        else ( browser_id > 2 )
            if ( browser_id <= 4 )
                if ( browser_id <= 3 )
                    if ( operating_system_id <= 2 )
                        262 Clicks - 4.09%

这是我用来显示这样一棵树的代码:

def get_code(count_df, tree, feature_names, target_names, spacer_base="    "):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value
    temp_list = []
    res_count = count_df
    def recurse(res_count, temp_list, left, right, threshold, features, node, depth):
        spacer = spacer_base * depth
        if (threshold[node] != -2):
            temp_list.append("if ( " + features[node] + " <= " + \
                str(int(round(threshold[node] - 1))) + " )")
            if left[node] != -1:
                    recurse (res_count, temp_list, left, right, threshold, features, left[node], depth+1)
            temp_list.append("else ( " + features[node] + " > " + \
                str(int(round(threshold[node] - 1))) + " )")
            if right[node] != -1:
                    recurse (res_count, temp_list, left, right, threshold, features, right[node], depth+1)

        else:
            target = value[node]
            for i, v in zip(np.nonzero(target)[1], target[np.nonzero(target)]):
                target_name = target_names[i]
                target_count = int(v)
                temp_list.append(str(target_count) +" "+ str(target_name) + " - " + str(round((target_count / res_count), 4) * 100)+ "%")

    recurse(res_count, temp_list, left, right, threshold, features, 0, 0)
    return temp_list

否则,请参考我的帖子here中提供的答案,但是它是用Scala写的,改变了Spark生成决策树的方式。

【讨论】:

我不能使用 sklearn 决策树,因为它们不支持在线/非核心训练。但是你得到的输出看起来可能是我想要的(你有两个标签,点击和转化是对的吗?)。你能提供一些代码来获得这个输出吗?我也可以从 spark mllib 模型中获取它吗?

以上是关于Spark mllib 决策树的统计信息的主要内容,如果未能解决你的问题,请参考以下文章

spark.mllib源码阅读-分类算法4-DecisionTree

Apache Spark:Mllib之决策树的操作(java)

Spark-Mllib基本统计

spark MLLib的基础统计部分学习

[机器学习Spark]Spark MLlib实现数据基本统计

spark MLlib实现的基于朴素贝叶斯(NaiveBayes)的中文文本自动分类