使用 GridSearchCV 时跳过禁止的参数组合

Posted

技术标签:

【中文标题】使用 GridSearchCV 时跳过禁止的参数组合【英文标题】:Skip forbidden parameter combinations when using GridSearchCV 【发布时间】:2017-08-17 23:39:54 【问题描述】:

我想使用GridSearchCV 贪婪地搜索我的支持向量分类器的整个参数空间。但是,LinearSVC 和throw an exception 禁止某些参数组合。特别是dualpenaltyloss参数的组合是互斥的:

例如这段代码:

from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()
parameters = 'dual':[True, False], 'penalty' : ['l1', 'l2'], \
              'loss': ['hinge', 'squared_hinge']
svc = svm.LinearSVC()
clf = GridSearchCV(svc, parameters)
clf.fit(iris.data, iris.target)

返回ValueError: Unsupported set of arguments: The combination of penalty='l2' and loss='hinge' are not supported when dual=False, Parameters: penalty='l2', loss='hinge', dual=False

我的问题是:是否可以让 GridSearchCV 跳过模型禁止的参数组合?如果没有,有没有简单的方法来构造一个不会违反规则的参数空间?

【问题讨论】:

如果我们至少可以在这种情况下抑制 FitFailedWarning 语句,这仍然是一个问题,但问题较小。我面临同样的战斗,我知道某些组合是非法的,但防止这些组合的逻辑(如下所述)太丑陋了。 【参考方案1】:

我通过将error_score=0.0 传递给GridSearchCV 解决了这个问题:

error_score : ‘raise’(默认)或数字

要分配给 如果在估计器拟合中发生错误,则得分。如果设置为“raise”,则 引发错误。如果给出一个数值,FitFailedWarning 是 提高。此参数不影响改装步骤,这将 总是引发错误。

更新:较新版本的 sklearn 打印出一堆 ConvergenceWarningFitFailedWarning。我很难用contextlib.suppress 压制他们,但there is a hack around that 涉及测试上下文管理器:

from sklearn import svm, datasets 
from sklearn.utils._testing import ignore_warnings 
from sklearn.exceptions import FitFailedWarning, ConvergenceWarning 
from sklearn.model_selection import GridSearchCV 

with ignore_warnings(category=[ConvergenceWarning, FitFailedWarning]): 
    iris = datasets.load_iris() 
    parameters = 'dual':[True, False], 'penalty' : ['l1', 'l2'], \ 
                 'loss': ['hinge', 'squared_hinge'] 
    svc = svm.LinearSVC() 
    clf = GridSearchCV(svc, parameters, error_score=0.0) 
    clf.fit(iris.data, iris.target)

【讨论】:

在它们实际输出任何错误之前,是否有一种解决方法可以实际避免这些组合(或任何其他组合)? @Khabz 我的答案太大而无法放入 cmets,所以我将其发布为另一个答案。 @crypdick 有没有办法避免在结果中看到 FitFailedWarning? @Nihat 我编辑了我的答案以抑制新警告【参考方案2】:

如果您想完全避免探索特定组合(无需等待遇到错误),您必须自己构建网格。 GridSearchCV 可以获取一个字典列表,其中探索列表中每个字典跨越的网格。

在这种情况下,条件逻辑并没有那么糟糕,但是对于更复杂的事情来说,它真的很乏味:

from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
from itertools import product

iris = datasets.load_iris()

duals = [True, False]
penaltys = ['l1', 'l2']
losses = ['hinge', 'squared_hinge']
all_params = list(product(duals, penaltys, losses))
filtered_params = ['dual': [dual], 'penalty' : [penalty], 'loss': [loss]
                   for dual, penalty, loss in all_params
                   if not (penalty == 'l1' and loss == 'hinge') 
                   and not ((penalty == 'l1' and loss == 'squared_hinge' and dual is True))
                  and not ((penalty == 'l2' and loss == 'hinge' and dual is False))]

svc = svm.LinearSVC()
clf = GridSearchCV(svc, filtered_params)
clf.fit(iris.data, iris.target)

【讨论】:

感谢您的努力,但这似乎是一个略显粗略的解决方案,对于具有大量限制的问题会导致大量冗长 @Khabz 同意,这段代码被诅咒了!如果有无数个条件,一种可能性是以编程方式在filtered_params 中构造条件列表,然后是str.join(conditionals_list),最后是eval() 字符串以进行列表理解。

以上是关于使用 GridSearchCV 时跳过禁止的参数组合的主要内容,如果未能解决你的问题,请参考以下文章

在 NGINX 中执行 301 重定向时跳过参数

使用 mysqldump 时跳过或忽略临时表

push_back 时跳过向量位置

在 netbeans 7 中,如何在构建 maven 项目时跳过测试并添加 maven 附加参数?

SpringBoot项目maven 打包时跳过测试

如何使用GridSearchCV获取所有模型(每组参数一个)?