scikit 学习。 GridSearchCV 管道中的自定义 Transformer set_params 逻辑。

Posted

技术标签:

【中文标题】scikit 学习。 GridSearchCV 管道中的自定义 Transformer set_params 逻辑。【英文标题】:scikit learn. Custom Transformer set_params logic in GridSearchCV pipeline. 【发布时间】:2018-02-04 13:20:35 【问题描述】:

我需要构建我的自定义转换器,在管道中使用它并评估它使用 GridSearchCV 调整该管道的参数。

按照here 的建议,我设法实现了简单的自定义转换器,但是 尝试使用内部估计器实现转换器并在 GridSearchCV 中使用此构造时出现问题。在我看来,我自己找不到答案,因为我不完全理解搜索方法(如(网格/随机化)SearchCV 和 set_params)的微妙之处。

“Python 机器学习简介”一书描述 GridSearchCV 的逻辑相当幼稚:

...iterating over each parameters combination...
    init estimator
    fit estimator
    evaluate

但是这种幼稚的方法无法回答我的问题。为了澄清我的问题,让我们看一下这个案例:

class OuterTransformer(BaseEstimator, TransformerMixin):
    _options = 'std':StandardScaler(),'mm':MinMaxScaler()
    def __init__(self, option='std'):
        ...

我的主要问题是“我应该把选择内部估计器的逻辑放在哪里?”。根据上面提到的帖子,这应该是这样的:

    def __init__(self, option='std'):
        self.option = option
    def fit(self, data, y=None):
        self.option = self._options[option] 
        ...

另一方面,常识规定 GridSearch 必须在调用 fit 之前传递参数来初始化内部估计器,因此应该在 __init__ 中选择内部估计器。

似乎第一种方法效果很好,但我就是不明白为什么。 有人可以向我解释一下这种现象吗?

【问题讨论】:

【参考方案1】:

看来我理解了估计器参数的初始化和重新初始化的逻辑。这有助于回答我的问题:

类字段必须用传递给构造函数的那些原始值来初始化,而不是它们的一些“衍生物”,因为对于每个重新-估计器的初始化,scikit 调用 __init__,传递在 CV 启动之前通过 get_params 方法从实例中提取的参数。

get_params的本质是扫描类的方法__init__的签名,并从名称对应于__init__ 的参数(当然 self 除外)。

因此,如果我们将 “派生” 值写入 __init__ 方法内的字段中,这些 “派生” 值将被转移到下一个重新初始化,这意味着一切都会失败。

class OuterTransformer(BaseEstimator, TransformerMixin):
    _options = 'std':StandardScaler(),'mm':MinMaxScaler()

    # good init- all fine
    def __init__(self, option='std'):
        self.option = option

    # bad init - will not work, because option is not an 'original' parameter.
    def __init__(self, option='std'):
        self.option = self._options[option] 

【讨论】:

以上是关于scikit 学习。 GridSearchCV 管道中的自定义 Transformer set_params 逻辑。的主要内容,如果未能解决你的问题,请参考以下文章

ValueError 在 Scikit 中找到最佳超参数时使用 GridSearchCV 学习 LogisticRegression

如何在 Scikit 中自定义 GridSearchCV 的指标 学习调整特定类?

scikit 学习。 GridSearchCV 管道中的自定义 Transformer set_params 逻辑。

了解 scikit-learn GridSearchCV - 参数调整和平均性能指标

Scikit Learn GridSearchCV 和 pipeline 使用不同的方法

评估 scikit-learn GridSearchCV 中交叉验证分数的平均值、标准差