用于 RandomForest 多类的 SHAP TreeExplainer:啥是 shap_values[i]?

Posted

技术标签:

【中文标题】用于 RandomForest 多类的 SHAP TreeExplainer:啥是 shap_values[i]?【英文标题】:SHAP TreeExplainer for RandomForest multiclass: what is shap_values[i]?用于 RandomForest 多类的 SHAP TreeExplainer:什么是 shap_values[i]? 【发布时间】:2021-04-09 11:21:47 【问题描述】:

我正在尝试绘制 SHAP 这是我的代码rnd_clfRandomForestClassifier

import shap 
explainer = shap.TreeExplainer(rnd_clf) 
shap_values = explainer.shap_values(X) 
shap.summary_plot(shap_values[1], X) 

我知道shap_values[0] 是负面的,shap_values[1] 是正面的。

但是对于多类 RandomForestClassifier 呢?我有rnd_clf 分类之一:

['Gusto','Kestrel 200 SCI Older Road Bike', 'Vilano Aluminum Road Bike 21 Speed Shimano', 'Fixie'].

如何确定shap_values[i] 的哪个索引对应于我的输出的哪个类?

【问题讨论】:

【参考方案1】:

如何确定 shap_values[i] 的哪个索引对应于我的输出的哪个类?

shap_values[i] 是第 i 个类的 SHAP 值。什么是第 i 类更多的是您使用的编码模式的问题:LabelEncoderpd.factorize 等。

您可以尝试以下作为线索:

from sklearn.preprocessing import LabelEncoder

labels = [
    "Gusto",
    "Kestrel 200 SCI Older Road Bike",
    "Vilano Aluminum Road Bike 21 Speed Shimano",
    "Fixie",
]
le = LabelEncoder()
y = le.fit_transform(labels)
encoding_scheme = dict(zip(y, labels))
pprint(encoding_scheme)

0: 'Fixie',
 1: 'Gusto',
 2: 'Kestrel 200 SCI Older Road Bike',
 3: 'Vilano Aluminum Road Bike 21 Speed Shimano'

所以,例如shap_values[3] 这个特殊情况是'Vilano Aluminum Road Bike 21 Speed Shimano'

为了进一步了解如何解释 SHAP 值,让我们为具有 100 个特征和 10 个类别的多类分类准备一个合成数据集:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from shap import TreeExplainer
from shap import summary_plot

X, y = make_classification(1000, 100, n_informative=8, n_classes=10)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
print(X_train.shape)

(750, 100)

此时,我们已经有了包含 750 行、100 个特征和 10 个类的训练数据集。

让我们训练RandomForestClassifier 并将其提供给TreeExplainer

clf = RandomForestClassifier(n_estimators=100, max_depth=3)
clf.fit(X_train, y_train)
explainer = TreeExplainer(clf)
shap_values = np.array(explainer.shap_values(X_train))
print(shap_values.shape)

(10, 750, 100)

10 : 类数。所有 SHAP 值都组织成 10 个数组,每个类 1 个数组。 750:数据点的数量。每个数据点都有本地 SHAP 值。 100:特征数量。我们对每个功能都有 SHAP 值。

例如,对于Class 3,您将拥有:

print(shap_values[3].shape)

(750, 100)

750:每个数据点的 SHAP 值 100:SHAP 对每个功能的价值贡献

最后,您可以运行健全性检查以确保模型的真实预测与shap 的预测相同。

为此,我们将 (1) 交换 shap_values 的前 2 个维度,(2) 将所有特征的每个类的 SHAP 值相加,(3) 将 SHAP 值添加到基值:

shap_values_ = shap_values.transpose((1,0,2))

np.allclose(
    clf.predict_proba(X_train),
    shap_values_.sum(2) + explainer.expected_value
)

True

然后您可以继续访问summary_plot,它将根据每个类别的 SHAP 值显示特征排名。对于第 3 类,这将是:

summary_plot(shap_values[3],X_train)

解释如下:

对于第 3 类,基于 SHAP 贡献的最有影响力的特征是 16、59、24

对于特征 15,较低的值往往会导致较高的 SHAP 值(因此类标签的概率较高)

在显示的 20 个特征中,特征 50、45、48 的影响最小

【讨论】:

最后的图错了吗?里面没有特征16、59等。 @dgg32 最后的图是对的。默认情况下,仅显示前 20 个功能。但你可以改变它,

以上是关于用于 RandomForest 多类的 SHAP TreeExplainer:啥是 shap_values[i]?的主要内容,如果未能解决你的问题,请参考以下文章

H2O randomForest中的多类分类

Pipeline 和 GridSearchCV,以及 XGBoost 和 RandomForest 的多类挑战

RandomForest 和 XGB 为啥/如何?对此有啥可以做的吗?

如何使用机器学习算法设置多类?

用于多类分类的 SVM(one-vs-all)中的置信度估计

将SVM用于多类分类