如何在 Scikit-Learn 中获取 GridSearchCV() 的 OneVsRestClassifier(LinearSVC()) 的估算器键参考?

Posted

技术标签:

【中文标题】如何在 Scikit-Learn 中获取 GridSearchCV() 的 OneVsRestClassifier(LinearSVC()) 的估算器键参考?【英文标题】:How do I get the estimator key reference to OneVsRestClassifier(LinearSVC()) for GridSearchCV() in Scikit-Learn? 【发布时间】:2020-01-15 17:55:17 【问题描述】:

我正在 scikit-learn 中对GridSearchCV 的超参数进行网格搜索。

这就是我准备 ML 算法及其要搜索的相关参数的方式。 LogisticRegression()RandomForestClassifier() 分别用它们正确的估计键 logisticregression__randomforestclassifier__ 指定。

ml_algo_param_dict = \
                   'LR_OVR': 'clf': LogisticRegression(),
                                'param': [
                                    'logisticregression__solver': ['lbfgs', 'liblinear'],
                                    'logisticregression__penalty': ['l2'],
                                    'logisticregression__C': [0.1, 1, 10],
                                    'logisticregression__class_weight': [None],
                                    'logisticregression__multi_class': ['ovr'],
                                    'logisticregression__max_iter': [1000, 4000],
                                , 
                                    'logisticregression__solver': ['newton-cg'],
                                    'logisticregression__penalty': ['l2'],
                                    'logisticregression__C': [0.1, 1, 10],
                                    'logisticregression__class_weight': [None],
                                    'logisticregression__multi_class': ['ovr'],
                                    'logisticregression__max_iter': [1000, 4000],
                                ],
                    'RF_OVR': 'clf': RandomForestClassifier(),
                                'param': [
                                    'randomforestclassifier__n_estimators': [100],
                                    'randomforestclassifier__max_depth': [150, 200],
                                    'randomforestclassifier__random_state': [888],
                                ],
                    'SVC_OVR': 'clf': OneVsRestClassifier(LinearSVC()),
                                'param': [
                                        'onevsrestclassifier_linearsvc__C': [100],
                                        'onevsrestclassifier_linearsvc__max_iter': [400, 6000],
                                ],

但是OneVsRestClassifier(LinearSVC()) 呢?我尝试了很多方法(即onevsrestclassifier_linearsvc__onevsrestclassifier__linearsvc__),但一直收到错误Check the list of available parameters with estimator.get_params().keys()。如何找到正确的估算器键?


添加以下代码以显示 dict 的使用方式

transformer_num = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler())])

transformer_cat = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='constant', fill_value='')),
    ('onehotencoder', OneHotEncoder(handle_unknown='ignore'))])

preprocessor = ColumnTransformer(
    transformers=[
        ('num', transformer_num, feature_list_num),
        ('cat', transformer_cat, feature_list_cat),
        ])

for algo_key, algo_val in ml_algo_param_dict.items():
    f1 = make_scorer(f1_score , average='micro')
    pipe = make_pipeline(preprocessor, algo_val['clf'])
    grid = GridSearchCV(pipe, algo_val['param'], n_jobs=-1, cv=5, scoring=f1, refit=True)
    grid.fit(X_train, y_train)

我试过'onevsrestclassifier_linearsvc__C', onevsrestclassifier_linearsvc_estimator__C', 'onevsrestclassifier__C', 'linearsvc__C', 'onevsrestclassifier__linearsvc__C', 'onevsrestclassifier-linearsvc__C', 'onevsrestclassifier_linearsvc_estimator__C', 'estimator__C',但都给了我同样的错误Check the list of available parameters with "estimator.get_params().keys()"

【问题讨论】:

【参考方案1】:

以下是正常工作的引用命名:

'SVC_OVR': 'clf': OneVsRestClassifier(LinearSVC()),
            'param': [
                'onevsrestclassifier__estimator__C': [1, 10],
                'onevsrestclassifier__estimator__max_iter': [10000],
                      ],

【讨论】:

以上是关于如何在 Scikit-Learn 中获取 GridSearchCV() 的 OneVsRestClassifier(LinearSVC()) 的估算器键参考?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Scikit-Learn 中获取 GridSearchCV() 的 OneVsRestClassifier(LinearSVC()) 的估算器键参考?

scikit-learn:在管道中使用 SelectKBest 时获取选定的功能

Extjs4.1 Grid-如何在 itemdblclick 函数中获取 cellIndex

使用 scikit-learn 进行文本分类:如何从 pickle 模型中获取新文档的表示

如何获取 Scikit-learn 的 svm 中的训练误差?

如何从 scikit-learn KMeans 中获取聚类中心的文本?