使用排列重要性绘制前 n 个特征
Posted
技术标签:
【中文标题】使用排列重要性绘制前 n 个特征【英文标题】:Plotting top n features using permutation importance 【发布时间】:2021-11-13 14:50:09 【问题描述】:import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.inspection import permutation_importance
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
result = permutation_importance(rf,
X_test,
y_test,
n_repeats=10,
random_state=42,
n_jobs=2)
sorted_idx = result.importances_mean.argsort()
fig, ax = plt.subplots()
ax.boxplot(result.importances[sorted_idx].T,
vert=False,
labels=X_test.columns[sorted_idx])
ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()
在上面的代码中,取自文档中的this example,有没有办法只绘制前 3 个特征而不是所有特征?
【问题讨论】:
这实际上与 scikit-learn 无关,并且有太多样板代码无法到达定义rf
、X_test
和 y_test
的地步。即便如此,这里的任何代码都不重要,因为这个问题归结为“我如何从列表中获取最后一个 n
元素”,这是 here 的答案,可能还有很多其他地方。但是,因为有赏金,所以问题仍然悬而未决。
【参考方案1】:
argsort
“返回对数组进行排序的索引”,因此这里sorted_idx
包含按最重要到最重要的顺序排列的特征索引。由于您只需要 3 个最重要的特征,因此只取最后 3 个索引:
sorted_idx = result.importances_mean.argsort()[-3:]
# array([4, 0, 1])
那么绘图代码可以保持原样,但现在它只会绘制前 3 个特征:
# unchanged
fig, ax = plt.subplots(figsize=(6, 3))
ax.boxplot(result.importances[sorted_idx].T,
vert=False, labels=X_test.columns[sorted_idx])
ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()
请注意,如果您希望保持 sorted_idx
不变(例如,在代码的其他地方使用完整索引),
将sorted_idx
更改为sorted_idx[-3:]
内联:
sorted_idx = result.importances_mean.argsort() # unchanged
ax.boxplot(result.importances[sorted_idx[-3:]].T, # replace sorted_idx with sorted_idx[-3:]
vert=False, labels=X_test.columns[sorted_idx[-3:]]) # replace sorted_idx with sorted_idx[-3:]
或将过滤后的索引存储在单独的变量中:
sorted_idx = result.importances_mean.argsort() # unchanged
top3_idx = sorted_idx[-3:] # store top 3 indices
ax.boxplot(result.importances[top3_idx].T, # replace sorted_idx with top3_idx
vert=False, labels=X_test.columns[top3_idx]) # replace sorted_idx with top3_idx
【讨论】:
以上是关于使用排列重要性绘制前 n 个特征的主要内容,如果未能解决你的问题,请参考以下文章
Python计算树模型(随机森林xgboost等)的特征重要度及其波动程度:基于熵减的特征重要度计算及可视化基于特征排列的特征重要性(feature permutation)计算及可视化