从python中的xgboost中提取决策规则

Posted

技术标签:

【中文标题】从python中的xgboost中提取决策规则【英文标题】:extracting decision rules from xgboost in python 【发布时间】:2019-06-25 22:53:43 【问题描述】:

我想在我即将推出的模型中使用 python 中的 xgboost。但是由于我们的生产系统是在 SAS 中,我试图从 xgboost 中提取决策规则,然后编写一个 SAS 评分代码来在 SAS 环境中实现这个模型。

我已经浏览了多个链接。以下是其中的一些:

How to extract decision rules (features splits) from xgboost model in python3?

xgboost deployment

以上两个链接对xgboost部署特别是Shiutang-Li给出的代码有很大帮助。但是,我的预测分数并不完全匹配。

以下是我目前尝试过的代码:

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.grid_search import GridSearchCV
%matplotlib inline
import graphviz
from graphviz import Digraph

#Read the sample iris data:
iris =pd.read_csv("C:\\Users\\XXXX\\Downloads\\Iris.csv")
#Create dependent variable:
iris.loc[iris["class"] != 2,"class"] = 0
iris.loc[iris["class"] == 2,"class"] = 1

#Select independent and dependent variable:
X = iris[["sepal_length","sepal_width","petal_length","petal_width"]]
Y = iris["class"]

xgdmat = xgb.DMatrix(X, Y) # Create our DMatrix to make XGBoost more efficient

#Build the sample xgboost Model:

our_params = 'eta': 0.1, 'seed':0, 'subsample': 0.8, 'colsample_bytree': 0.8, 
             'objective': 'binary:logistic', 'max_depth':3, 'min_child_weight':1 
Base_Model = xgb.train(our_params, xgdmat, num_boost_round = 10)

#Below code reads the dump file created by xgboost and writes a scoring code in SAS:

import re
def string_parser(s):
    if len(re.findall(r":leaf=", s)) == 0:
        out  = re.findall(r"[\w.-]+", s)
        tabs = re.findall(r"[\t]+", s)
        if (out[4] == out[8]):
            missing_value_handling = (" or missing(" + out[1] + ")")
        else:
            missing_value_handling = ""

        if len(tabs) > 0:
            return (re.findall(r"[\t]+", s)[0].replace('\t', '    ') + 
                    '        if state = ' + out[0] + ' then do;\n' +
                    re.findall(r"[\t]+", s)[0].replace('\t', '    ') +
                    '            if ' + out[1] + ' < ' + out[2] + missing_value_handling +
                    ' then state = ' + out[4] + ';' +  ' else state = ' + out[6] + ';\nend;' ) 
        else:
            return ('        if state = ' + out[0] + ' then do;\n' +
                    '            if ' + out[1] + ' < ' + out[2] + missing_value_handling +
                    ' then state = ' + out[4] + ';' +  ' else state = ' + out[6] + ';\nend;' )
    else:
        out = re.findall(r"[\w.-]+", s)
        return (re.findall(r"[\t]+", s)[0].replace('\t', '    ') + 
                '        if state = ' + out[0] + ' then\n    ' +
                re.findall(r"[\t]+", s)[0].replace('\t', '    ') + 
                '        value = value + (' + out[2] + ') ;\n')

def tree_parser(tree, i):
    return ('state = 0;\n'
             + "".join([string_parser(tree.split('\n')[i]) for i in range(len(tree.split('\n'))-1)]))

def model_to_sas(model, out_file):
    trees = model.get_dump()
    result = ["value = 0;\n"]
    with open(out_file, 'w') as the_file:
        for i in range(len(trees)):
            result.append(tree_parser(trees[i], i))
        the_file.write("".join(result))
        the_file.write("\nY_Pred1 = 1/(1+exp(-value));\n")
        the_file.write("Y_Pred0 = 1 - Y_pred1;") 

调用上述模块创建SAS评分代码:

model_to_sas(Base_Model, 'xgb_scr_code.sas')

很遗憾,我无法提供上述模块生成的完整 SAS 代码。但是,如果我们只使用一个树代码构建模型,请在下面找到 SAS 代码:

value = 0;
state = 0;
if state = 0 then
    do;
        if sepal_width < 2.95000005 or missing(sepal_width) then state = 1;
        else state = 2;
    end;
if state = 1 then
    do;
        if petal_length < 4.75 or missing(petal_length) then state = 3;
        else state = 4;
    end;

if state = 3 then   value = value + (0.1586207);
if state = 4 then   value = value + (-0.127272725);
if state = 2 then
    do;
        if petal_length < 3 or missing(petal_length) then state = 5;
        else state = 6;
    end;
if state = 5 then   value = value + (-0.180952385);
if state = 6 then
    do;
        if petal_length < 4.75 or missing(petal_length) then state = 7;
        else state = 8;
    end;
if state = 7 then   value = value + (0.142857149);
if state = 8 then   value = value + (-0.161290333);

Y_Pred1 = 1/(1+exp(-value));
Y_Pred0 = 1 - Y_pred1;

下面是第一棵树的转储文件输出:

booster[0]:
    0:[sepal_width<2.95000005] yes=1,no=2,missing=1
        1:[petal_length<4.75] yes=3,no=4,missing=3
            3:leaf=0.1586207
            4:leaf=-0.127272725
        2:[petal_length<3] yes=5,no=6,missing=5
            5:leaf=-0.180952385
            6:[petal_length<4.75] yes=7,no=8,missing=7
                7:leaf=0.142857149
                8:leaf=-0.161290333

所以基本上,我要做的是,将节点号保存在变量“状态”中并相应地访问叶节点(我从上述链接中提到的 Shiutang-Li 的文章中了解到)。

这是我面临的问题:

对于最多大约 40 棵树,预测分数完全匹配。例如请看下面:

案例一:

使用python预测10棵树的值:

Y_pred1 = Base_Model.predict(xgdmat)

print("Development- Y_Actual: ",np.mean(Y)," Y predicted: ",np.mean(Y_pred1))

输出:

Average- Y_Actual:  0.3333333333333333  Average Y predicted:  0.4021197

使用 SAS 对 10 棵树的预测值:

Average Y predicted:  0.4021197

案例 2:

使用 python 预测 100 棵树的值:

Y_pred1 = Base_Model.predict(xgdmat)

print("Development- Y_Actual: ",np.mean(Y)," Y predicted: ",np.mean(Y_pred1))

输出:

Average- Y_Actual:  0.3333333333333333  Average Y predicted:  0.33232176

使用 SAS 对 100 棵树的预测值:

Average Y predicted:  0.3323159

如您所见,100 棵树的分数并不完全匹配(最多匹配 4 个小数点)。另外,我已经在分数差异很大的大文件上尝试过这个,即分数偏差超过 10%。

谁能让我指出我的代码中的任何错误,以便分数可以完全匹配。以下是我的一些疑问:

1)我的分数计算是否正确。

2)我发现了一些与 gamma(正则化项)有关的东西。是否会影响 xgboost 使用叶值计算分数的方式。

3)转储文件给出的叶子值是否会进行四舍五入,从而产生这个问题

此外,除了解析转储文件之外,我希望有任何其他方法来完成此任务。

P.S.:我只有 SAS EG,无法访问 SAS EM 或 SAS IML。

【问题讨论】:

【参考方案1】:

我在获得匹配分数方面也有类似的经历。 我的理解是,除非您修复ntree_limit 选项以匹配您在模型拟合期间使用的n_estimators,否则评分可能会提前停止。

df['score']= xgclfpkl.predict(df[xg_features], ntree_limit=500)

在我开始使用ntree_limit 后,我开始获得匹配的分数。

【讨论】:

嗨,KKane,非常感谢您的评论,因为我现在陷入了困境。不过,我没听懂你说的。您的意思是 XGboost 提前自动停止树的数量吗?你能帮我理解这一点吗?此外,您似乎已经解决了这个问题。您能否通过在此处发布代码来帮助我。非常感谢您的帮助。 嗨,KKane,你能回复我上面的问题吗【参考方案2】:

我有类似的经历,需要将 xgboost 评分代码从 R 提取到 SAS。

最初,我遇到了与您在这里相同的问题,即在较小的树中,R 和 SAS 的分数之间没有太大差异,一旦树的数量达到 100 或更多,我就开始观察差异。

我做了 3 件事来缩小差异:

    确保丢失的组朝着正确的方向前进,您需要明确说明。否则 SAS 会将缺失值视为所有数字中的最小值。该规则应类似于 SAS 中的以下内容。

if sepal_width &gt; 2.95000005 or missing(sepal_width) then state = 1;else state = 2;if sepal_width &lt;= 2.95000005 and ~missing(sepal_width) then state = 1;else state = 2;

    我使用了一个名为 float 的 R 包来使分数有更多的小数位。 as.numeric(float::fl(Quality))

    确保 SAS 数据与您在 Python 中训练的数据具有相同的形状。

希望以上内容有所帮助。

【讨论】:

您好 DDZR,感谢您的回复。我已经处理了第 1 点和第 3 点。但是我没有理解第 2 点。我正在从转储文件中读取浮点值,如何在这里增加小数点?【参考方案3】:

我稍微考虑过将其合并到我自己的代码中。

我发现缺少处理存在一个小问题。

如果你有这样的逻辑,它似乎可以正常工作

if petal_length < 3 or missing(petal_length) then state = 5;
        else state = 6;

但是说丢失的组应该进入状态 6 而不是状态 5。然后你会得到这样的代码:

if petal_length < 3 then state = 5;
        else state = 6;

petal_length = missing (.) 在这种情况下会进入什么状态? 好吧,它仍然进入状态 5(而不是预期的状态 6),因为在 SAS 中缺失被归类为小于任何数字。

要解决此问题,您可以将所有缺失值分配给 999999999999999(选择一个大数字,因为 XGBoost 格式总是使用小于 (

missing_value_handling = (" or missing(" + out[1] + ")")

missing_value_handling = (" or " + out[1] + "=999999999999999 ")

在你的string_parser

【讨论】:

非常感谢大卫的建议。我已经做出了这些改变。但是,在我的数据中没有丢失的观察结果,问题仍然存在,即分数在小数点后 4 位后仍然不匹配。该代码正在正确读取转储文件,因此只有在转储文件本身存在小数点错误时才会发生这种情况。但是,我当然不相信会出现这种情况,因此不确定该怎么做。您能否建议可能导致这种情况的原因。非常感谢您的帮助【参考方案4】:

几个点-

首先,匹配叶返回值的正则表达式捕获转储中的“电子十进制”科学记数法(默认)。显式示例(第二个是正确的修改!)-

s = '3:leaf=9.95066429e-09'
out = re.findall(r"[\d.-]+", s)
out2 = re.findall(r"-?[\d.]+(?:e-?\d+)?", s)
out2,out

(很容易修复,但由于我的模型中恰好有一片叶子受到影响,所以无法发现!)

其次,问题是关于二进制的,但在多类目标中,转储中的每个类都有单独的树,因此您总共有 T*C 树,其中 T 是提升轮数,C是类的数量。对于 c 类(在 0,1,...,C-1 中),您需要为 i = 0,...,T-1 评估树 i*C +c(并总结终端叶子)。然后对其进行 softmax 以匹配来自 xgb 的预测。

【讨论】:

非常感谢 P.Windridge 对从转储文件中读取科学记数法进行的更正。我已经进行了必要的更改。但是,在上述示例中,没有与错误相关的科学记数法,并且问题保持不变。你能否让我知道如何解决这个问题。 PS:我的数据中没有任何缺失值。

以上是关于从python中的xgboost中提取决策规则的主要内容,如果未能解决你的问题,请参考以下文章

如何在python中提取随机森林的决策规则

如何生产 XG 提升 / 决策树 / 随机森林模型

如何从 scikit-learn 决策树中提取决策规则?

如何从 scikit-learn 决策树中提取决策规则?

如何从 scikit-learn 决策树中提取决策规则?

如何从 scikit-learn 决策树中提取决策规则?