如何使用 scikit-learn API 实现元估计器?

Posted

技术标签:

【中文标题】如何使用 scikit-learn API 实现元估计器?【英文标题】:How to implement a meta-estimator with the scikit-learn API? 【发布时间】:2020-03-07 07:32:59 【问题描述】:

我想实现一个与所有 scikit-learn 兼容的简单包装器/元估计器。很难找到我到底需要什么的完整描述。

目标是有一个回归器,它也学习一个阈值来成为一个分类器。所以我想出了:

from sklearn.base import BaseEstimator, ClassifierMixin, clone

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor):
        self.regressor = regressor
        # threshold_ does not get initialized in __init__ ??

    def fit(self, X, y, optimal_threshold):
        self.regressor = clone(self.regressor)    # is this required my sklearn??
        self.regressor.fit(X, y)

        y_raw = self.regressor.predict()
        self.threshold_ = optimal_threshold(y_raw)

    def predict(self, X):
        y_raw = self.regressor.predict(X)

        y = np.digitize(y_raw, [self.threshold_])

        return y

这是否实现了我需要的完整 API?

我的主要问题是把threshold 放在哪里。我希望它只学习一次,并且可以在后续 .fit 调用中重新使用新数据而无需重新调整。但是对于当前版本,必须在每次调用 .fit 时重新调整它——我不希望这样吗?

另一方面,如果我将其设为固定参数self.threshold 并将其传递给__init__,那么我不应该用数据更改它吗?

我怎样才能创建一个threshold 参数,该参数可以在.fit 的一次调用中调整并在后续.fit 调用中修复?

【问题讨论】:

请问多次fit 电话的原因是什么?这是某种在线学习吗?还是由于交叉验证?还是别的什么? @ShihabShahriarKhan 只有一个与特定数据集合(以及存储在最佳阈值中的特定测试数据)的拟合才能确定阈值。从那时起,我希望不再调整阈值,而只调整我的数据折叠的回归量。 我可能误解了一些东西...如果您在init中将threshold_初始化为None,并检查它的值是否设置在fit中,它不会工作吗?有点类似于warm_startparam 很难说它是否有效,因为 sklearn 有一定的 API 和函数,如 clone 和其他元估计器在幕后做了某种魔术。这就是为什么我想知道使用 sklearn 的正确方法。 为什么不在构造函数中初始化 self.threshold = None 而不是 if 语句 - if self.threshold is not None: self.threshold = optimal_threshold(y_raw)?虽然我认为更好的方法是在 fit 方法中添加一个布尔值,指示是否更新阈值 【参考方案1】:

I actually wrote a blog post about this the other day。我假设您正在尝试构建类似于 TransformedTargetRegressor 的东西,我建议您查看其源代码以构建类似的东西。

您当前的实现似乎是正确的。就这个问题而言:

如何制作一个阈值参数,该参数可以在一次调用 .fit 中进行调整,并为后续的.fit 调用进行修复?

我不建议这样做,因为 scikit-learn 的 API 是基于 fit 方法重新拟合模型的所有可调方面。您可以在此处使用两条路线,将**kwarg 添加到明确保护theshold 不更新的匹配项,或者您可以使用@rotem-tal suggested。如果你选择后者,它可能看起来像这样:

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin

def optimal_threshold(y_raw: np.ndarray) -> np.ndarray:
    return np.array([0.1, 0.5, 1])  # some implementation here

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor):
        self.regressor = regressor
        self.threshold = None

    def fit(self, X, y, optimal_threshold):
        # you don't need to clone the regressor
        self.regressor.fit(X, y)

        y_raw = self.regressor.predict()
        if self.threshold is None:
            self.threshold = optimal_threshold(y_raw)

    def predict(self, X):
        y_raw = self.regressor.predict(X)

        y = np.digitize(y_raw, [self.threshold_])

        return y

【讨论】:

您需要遵循 sklearn API 否则将无法正常工作。我只是不知道 sklearn 的意图。您的版本将不起作用,因为clone 不会复制.threshold 属性。但是,cross_validate 使用 clone 。因此cross_validate 无法使用我的固定阈值!? 不确定您的意思,正如我在帖子中提到的那样,您尝试做的事情是非标准的,并且可能无法与生态系统的其他部分(管道等)很好地集成,但是这绝对是可能的。 我正在尝试找到一种方法来制作满足标准的估算器。并非所有不在 sklearn 中的东西都会自动成为非标准。 就编程接口而言,以上将符合 API 标准,试一试。不过,“从概念上讲”,它不符合标准,因为在调用 fit 时阈值不会更新。 如果不支持sklearn的关键功能,则不符合API标准。制作具有相似名称的函数不足以遵循 API,因为这个概念也必须得到满足。此版本的代码不适用于具有固定阈值的cross_validate。我试了一下,由于解释的原因,它只是不起作用。 cross_validate 将为每个折叠重新设置阈值。

以上是关于如何使用 scikit-learn API 实现元估计器?的主要内容,如果未能解决你的问题,请参考以下文章

Scikit-learn 安装 - 准备元数据 (pyproject.toml) ... 错误

如何在 scikit-learn 的 SVM 中使用非整数字符串标签? Python

k近邻算法api初步使用

如何使用 scikit-learn 进行高斯/多项式回归?

k近邻算法api初步使用

python机器学习——使用scikit-learn训练感知机模型