如何使用 scikit-learn API 实现元估计器?
Posted
技术标签:
【中文标题】如何使用 scikit-learn API 实现元估计器?【英文标题】:How to implement a meta-estimator with the scikit-learn API? 【发布时间】:2020-03-07 07:32:59 【问题描述】:我想实现一个与所有 scikit-learn 兼容的简单包装器/元估计器。很难找到我到底需要什么的完整描述。
目标是有一个回归器,它也学习一个阈值来成为一个分类器。所以我想出了:
from sklearn.base import BaseEstimator, ClassifierMixin, clone
class Thresholder(BaseEstimator, ClassifierMixin):
def __init__(self, regressor):
self.regressor = regressor
# threshold_ does not get initialized in __init__ ??
def fit(self, X, y, optimal_threshold):
self.regressor = clone(self.regressor) # is this required my sklearn??
self.regressor.fit(X, y)
y_raw = self.regressor.predict()
self.threshold_ = optimal_threshold(y_raw)
def predict(self, X):
y_raw = self.regressor.predict(X)
y = np.digitize(y_raw, [self.threshold_])
return y
这是否实现了我需要的完整 API?
我的主要问题是把threshold
放在哪里。我希望它只学习一次,并且可以在后续 .fit
调用中重新使用新数据而无需重新调整。但是对于当前版本,必须在每次调用 .fit
时重新调整它——我不希望这样吗?
另一方面,如果我将其设为固定参数self.threshold
并将其传递给__init__
,那么我不应该用数据更改它吗?
我怎样才能创建一个threshold
参数,该参数可以在.fit
的一次调用中调整并在后续.fit
调用中修复?
【问题讨论】:
请问多次fit
电话的原因是什么?这是某种在线学习吗?还是由于交叉验证?还是别的什么?
@ShihabShahriarKhan 只有一个与特定数据集合(以及存储在最佳阈值中的特定测试数据)的拟合才能确定阈值。从那时起,我希望不再调整阈值,而只调整我的数据折叠的回归量。
我可能误解了一些东西...如果您在init中将threshold_
初始化为None,并检查它的值是否设置在fit
中,它不会工作吗?有点类似于warm_start
param
很难说它是否有效,因为 sklearn 有一定的 API 和函数,如 clone
和其他元估计器在幕后做了某种魔术。这就是为什么我想知道使用 sklearn 的正确方法。
为什么不在构造函数中初始化 self.threshold = None
而不是 if 语句 - if self.threshold is not None: self.threshold = optimal_threshold(y_raw)
?虽然我认为更好的方法是在 fit
方法中添加一个布尔值,指示是否更新阈值
【参考方案1】:
I actually wrote a blog post about this the other day。我假设您正在尝试构建类似于 TransformedTargetRegressor
的东西,我建议您查看其源代码以构建类似的东西。
您当前的实现似乎是正确的。就这个问题而言:
如何制作一个阈值参数,该参数可以在一次调用
.fit
中进行调整,并为后续的.fit
调用进行修复?
我不建议这样做,因为 scikit-learn
的 API 是基于 fit
方法重新拟合模型的所有可调方面。您可以在此处使用两条路线,将**kwarg
添加到明确保护theshold
不更新的匹配项,或者您可以使用@rotem-tal suggested。如果你选择后者,它可能看起来像这样:
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
def optimal_threshold(y_raw: np.ndarray) -> np.ndarray:
return np.array([0.1, 0.5, 1]) # some implementation here
class Thresholder(BaseEstimator, ClassifierMixin):
def __init__(self, regressor):
self.regressor = regressor
self.threshold = None
def fit(self, X, y, optimal_threshold):
# you don't need to clone the regressor
self.regressor.fit(X, y)
y_raw = self.regressor.predict()
if self.threshold is None:
self.threshold = optimal_threshold(y_raw)
def predict(self, X):
y_raw = self.regressor.predict(X)
y = np.digitize(y_raw, [self.threshold_])
return y
【讨论】:
您需要遵循 sklearn API 否则将无法正常工作。我只是不知道 sklearn 的意图。您的版本将不起作用,因为clone
不会复制.threshold
属性。但是,cross_validate
使用 clone
。因此cross_validate
无法使用我的固定阈值!?
不确定您的意思,正如我在帖子中提到的那样,您尝试做的事情是非标准的,并且可能无法与生态系统的其他部分(管道等)很好地集成,但是这绝对是可能的。
我正在尝试找到一种方法来制作满足标准的估算器。并非所有不在 sklearn 中的东西都会自动成为非标准。
就编程接口而言,以上将符合 API 标准,试一试。不过,“从概念上讲”,它不符合标准,因为在调用 fit
时阈值不会更新。
如果不支持sklearn的关键功能,则不符合API标准。制作具有相似名称的函数不足以遵循 API,因为这个概念也必须得到满足。此版本的代码不适用于具有固定阈值的cross_validate
。我试了一下,由于解释的原因,它只是不起作用。 cross_validate
将为每个折叠重新设置阈值。以上是关于如何使用 scikit-learn API 实现元估计器?的主要内容,如果未能解决你的问题,请参考以下文章
Scikit-learn 安装 - 准备元数据 (pyproject.toml) ... 错误