用于管道的网格搜索参数网格的说明

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了用于管道的网格搜索参数网格的说明相关的知识,希望对你有一定的参考价值。

'feature_selection__k': list(range(1, len(feature_importances) + 1))代码在什么意思:

param_grid = [{
    'preparation__num__imputer__strategy': ['mean', 'median', 'most_frequent'],
    'feature_selection__k': list(range(1, len(feature_importances) + 1))
}]

grid_search_prep = GridSearchCV(prepare_select_and_predict_pipeline, param_grid, cv=5,
                                scoring='neg_mean_squared_error', verbose=2)
grid_search_prep.fit(housing, housing_labels)

where

full_pipeline = ColumnTransformer([
        ("num", num_pipeline, num_attribs),
        ("cat", OneHotEncoder(), cat_attribs),
    ])

num_pipeline = Pipeline([
        ('imputer', SimpleImputer(strategy="median")),
        ('attribs_adder', CombinedAttributesAdder()),
        ('std_scaler', StandardScaler()),
    ])

您能解释一下'feature_selection__k': list(range(1, len(feature_importances) + 1))行中的每一步吗?


如果需要所有代码,则在这里:https://github.com/ageron/handson-ml2/blob/master/02_end_to_end_machine_learning_project.ipynb。我要问的代码部分在笔记本底部。

答案

因为GridSearchCV在这里不适用于简单的估算器,但在管道中适用:

prepare_select_and_predict_pipeline = Pipeline([
    ('preparation', full_pipeline),
    ('feature_selection', TopFeatureSelector(feature_importances, k)),
    ('svm_reg', SVR(**rnd_search.best_params_))
])

where

full_pipeline = ColumnTransformer([
        ("num", num_pipeline, num_attribs),
        ("cat", OneHotEncoder(), cat_attribs),
    ])

num_pipeline = Pipeline([
        ('imputer', SimpleImputer(strategy="median")),
        ('attribs_adder', CombinedAttributesAdder()),
        ('std_scaler', StandardScaler()),
    ])

param_grid需要组织成“级别”,以便它知道它将访问组成管道组件的确切参数。

因此,字符串的feature_selection部分是指prepare_select_and_predict_pipeline的各个部分,k是指TopFeatureSelector的各个参数。级别由双下划线__分隔,因此要访问k的参数TopFeatureSelector,相应的参数定义为feature_selection__k

出于相同的原因,为了使param_grid访问strategySimpleImputer参数,相应的条目为'preparation__num__imputer__strategy',即:

  • 第一个preparationprepare_select_and_predict_pipeline部分
  • 第二个numfull_pipeline部分
  • imputernum_pipeline部分
  • strategy的参数SimpleImputer

每个条目,如前所述,用双下划线__分隔。

另一答案

[list(range(1, len(feature_importances) + 1))返回一个从1到n的整数列表,其中n是feature_importances的长度。

例如,如果feature_importances的长度为5。

list(range(1, len(feature_importances) + 1))

将返回:

[1, 2, 3, 4, 5]

您需要添加“ +1”,因为范围不包括上限。

以上是关于用于管道的网格搜索参数网格的说明的主要内容,如果未能解决你的问题,请参考以下文章

scikit-learn 管道:对变压器参数进行网格搜索以生成数据

如何在 scikit-learn 的管道中对变换参数进行网格搜索

在 Pipeline 上进行网格搜索后更新变压器参数

使用管道进行岭回归网格搜索

使用网格搜索获得最佳模型的“并行”管道

网格搜索预处理多个超参数和多个估计器