sklearn 中处理网格搜索组合的顺序是啥?

Posted

技术标签:

【中文标题】sklearn 中处理网格搜索组合的顺序是啥?【英文标题】:What order are grid search combinations handled in sklearn?sklearn 中处理网格搜索组合的顺序是什么? 【发布时间】:2021-02-27 01:40:19 【问题描述】:

我对 sklearn 的 GridSearchCV 对象处理其超参数组合的顺序有疑问。具体来说,我使用带有参数的 sklearn 执行了网格搜索:

param1 = [val1, val2, val3, val4, val5]
param2 = [num1, num2]

cv_results_mean_test_score 属性是一个长度为 10 的数组,正如预期的那样(len(param1)*len(param2));但是,我不知道哪个值对应于什么组合。也就是说,param1 的值是否被保持,param2 被循环,反之亦然。

mean_test_score中的10个值是否对应

[ [val1, num1], [val1, num2], [val2, num1], [val2, num2], ... ]

(其中param2param1 之前循环)或

[ [val1, num1], [va2, num1], [val3, num1], [val4, num1], [val5, num1], [val1, num2], ... ]

(其中param1param2 之前循环)。它是否仅取决于它们在网格搜索中指定的顺序?我可以根据一个特定的超参数值返回结果吗?

谢谢!

【问题讨论】:

【参考方案1】:

GridSearchCV 在内部使用名为 ParameterGrid 的类,您可以查看 here(第 47、114 行)

这或多或少是ParameterGrid 在您的GridSearchCV 中所做的:

from itertools import product

grid_values= ["param1": [1, 2, 3, 4, 5], "param2": [1, 2]]

def grid(grid_values):
    for p in grid_values:
        # Always sort the keys of a dictionary, for reproducibility
        print(p)
        items = sorted(p.items())
        if not items:
            yield 
        else:
            keys, values = zip(*items)
            for v in product(*values):
                params = dict(zip(keys, v))
                yield params

它首先将你的字典包装在一个列表中(因为它可以处理不同类型的数据作为输入,例如字典列表)

grid_values= ["param1": [1, 2, 3, 4, 5], "param2": [1, 2]]

之后,它会对您的 dict 的键执行排序,以实现可重复性。这将决定你的组合

  items = sorted(p.items())

然后它使用来自itertoolsproduct 函数,它执行您的想法(here details)。变量上的嵌套 for 循环。但是从按参数名称排序的值开始!

for v in product(*values):
    params = dict(zip(keys, v))
    yield params

Check also the doc of ParameterGrid

【讨论】:

【参考方案2】:

如果你这样做

import pandas as pd
pd.DataFrame(clf.cv_results_)

param_param1param_param2 列将为您提供每个组合的相应参数。

当然,你也可以使用一个通用的索引来迭代它,但是使用 pandas 是很容易的。

【讨论】:

我认为OP对如何生成参数组合的顺序很感兴趣

以上是关于sklearn 中处理网格搜索组合的顺序是啥?的主要内容,如果未能解决你的问题,请参考以下文章

使用管道和网格搜索执行特征选择

在处理 VotingClassifier 或网格搜索时,Sklearn 中的 GradientBoostingClassifier 是不是有类权重(或替代方式)?

管道和网格搜索的 SKLearn 错误

在 sklearn 中制作网格搜索功能以忽略空模型

scikit-learn 中的超参数优化(网格搜索)

使用网格搜索获得最佳模型的“并行”管道