如何检查 CalibratedClassifierCV BaseEstimator 参数

Posted

技术标签:

【中文标题】如何检查 CalibratedClassifierCV BaseEstimator 参数【英文标题】:How to inspect CalibratedClassifierCV BaseEstimator parameters 【发布时间】:2021-03-05 22:56:30 【问题描述】:

我的任务是检查和比较由其他人开发的两个已经训练好的机器学习模型。模型的区别在于不同的输入数据集。第一个是在 2018 年的数据上训练的,第二个是在 2019 年的数据上训练的。它的核心是 RandomForestClassifier 模型,在 sklearn.pipeline 模块中训练。问题是,中间有一个CalibratedClassifierCV,这使我对随机森林模型本身的访问变得复杂。所以我对模型没有深入的了解,它对我来说就像一个黑盒子。两种情况下的管道是相同的。

编辑:在管道的创建方式中添加了可重现的步骤,但没有数据集:

from sklearn.ensemble import RandomForestClassifier
from sklearn.calibration import CalibratedClassifierCV, calibration_curve
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

RF_clf = RandomForestClassifier()
pipeline = Pipeline([('scaling', StandardScaler(with_mean=False)),
                        ('classifier', CalibratedClassifierCV(base_estimator=RF_clf, cv=2, method='sigmoid'))])

# Further steps not replicated from code because of additional custom made training and fitting functions but these are the steps:

# fit model on train data
# predict model on test data

下一步我想做或看到的,从我从磁盘读取为model.pkl 文件的已经训练好的模型,是模型的.feature_importance_,因为随机森林本身就支持它。然后我会比较这些年份之间最重要特征的分布。但无法访问。

这是我在模型检查方面能走多远:

这些有效:

pipeline.named_steps
pipeline.named_steps['classifier']

CalibratedClassifierCV(base_estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features=1, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=1,
            oob_score=False, random_state=0, verbose=1, warm_start=False),
            cv=2, method='sigmoid')

但我不能再深入了。

这个已经不行了:

pipeline.named_steps['classifier']['base_estimator']

TypeError: 'CalibratedClassifierCV' object is not subscriptable

我还尝试了eli5 库以查看一些信息,但似乎不支持CalibratedClassifierCV

eli5.explain_weights(pipeline_rf.named_steps['classifier'])

Error: estimator CalibratedClassifierCV(base_estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini', max_depth=None, max_features=1, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=1, oob_score=False, random_state=0, verbose=1, warm_start=False), cv=2, method='sigmoid') is not supported 

您是否有一些关于如何深入了解CalibratedClassifierCV 并真正了解BaseEstimator 模型的功能重要性的经验?感谢您提供任何提示。

【问题讨论】:

pipeline.named_steps['classifier'].base_estimator 怎么样? 请考虑minimal reproducible example。作为旁注,BaseEstimator 是一个父类,所有估计器都从该类继承。可能不是你想要的 【参考方案1】:

访问底层RandomForestClassifierfeature_importances_的简短示例。

from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn import set_config
set_config(print_changed_only=True)

X, y = make_classification()
rf = RandomForestClassifier().fit(X, y)

pipe = Pipeline([('classifier',
                  CalibratedClassifierCV(rf))]).fit(X, y)
pipe['classifier'].base_estimator.feature_importances_

如果您正在寻找CalibratedClassifierCVs 的输出比较,那么您需要查看calibrated_classifiers_ 属性。

更新:

对于 sigmoid 方法(这是默认方法),在拟合期间学习了两个参数 (a, b)。我们可以通过下面的 sn-p 提取这两个信息。

for calibrated_classifier in pipe['classifier'].calibrated_classifiers_ :
    calibrator = calibrated_classifier.calibrators_[0]
    print(calibrator.a_)
    print(calibrator.b_)

【讨论】:

嗨 Venkatachalam,感谢您的回答,我可以通过这种方式获得结果!太糟糕了,我找不到您提供的订阅示例,例如。 base_estimator.feature_importances_ 而不是 ['base_estimator'].feature_importances_。能否请您也给我举例说明如何比较CalibratedClassifierCV 对不起,你couldn't find example of the subscription的时候我听不懂,能详细点吗? 我的意思是,我在网上找不到关于如何正确调用函数参数的示例。这就是为什么我错误地尝试了['base_estimator'].feature_importances_ 而不是正确地尝试了base_estimator.feature_importances_。所以你能帮助我真是太好了,这个小小的帮助已经解决了我的问题。

以上是关于如何检查 CalibratedClassifierCV BaseEstimator 参数的主要内容,如果未能解决你的问题,请参考以下文章

如何检查元素是不是没有特定的类?

如何进行PYTHON语法检查

单击然后检查时如何设置字体真棒检查图标[关闭]

如何编写Windows安全检查脚本

如何检查网络关闭

如何检查 XmlAttributeCollection 中是不是存在属性?