继承自SciKit FunctionTransformer
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了继承自SciKit FunctionTransformer相关的知识,希望对你有一定的参考价值。
我想使用FunctionTransformer
,同时提供一个简单的API并隐藏其他详细信息。具体来说,我希望能够提供一个Custom_Trans
类,如下所示。因此,除了可以正常工作的trans1
之外,用户还应该可以使用目前失败的trans2
:
from sklearn import preprocessing
from sklearn.pipeline import Pipeline
from sklearn import model_selection
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression
import numpy as np
X, y = make_regression(n_samples=100, n_features=1, noise=0.1)
def func(X, a, b):
return X[:,a:b]
class Custom_Trans(preprocessing.FunctionTransformer):
def __init__(self, ind0, ind1):
super().__init__(
func=func,
kw_args={
"a": ind0,
"b": ind1
}
)
trans1 = preprocessing.FunctionTransformer(
func=func,
kw_args={
"a": 0,
"b": 50
}
)
trans2 = Custom_Trans(0,50)
pipe1 = Pipeline(
steps=[
('custom', trans1),
('linear', LinearRegression())
]
)
pipe2 = Pipeline(
steps=[
('custom', trans2),
('linear', LinearRegression())
]
)
print(model_selection.cross_val_score(
pipe1, X, y, cv=3,)
)
print(model_selection.cross_val_score(
pipe2, X, y, cv=3,)
)
这就是我得到的:
[0.99999331 0.99999671 0.99999772]
...sklearn/base.py:209: FutureWarning: From version 0.24, get_params will raise an
AttributeError if a parameter cannot be retrieved as an instance attribute.
Previously it would return None.
warnings.warn('From version 0.24, get_params will raise an '
...
[0.99999331 0.99999671 0.99999772]
我有点知道这与估计器克隆有关,但我不知道如何解决。例如this post表示
估计器中应该没有逻辑,甚至没有输入验证init。逻辑应放在使用参数的位置,通常适合]
但是在这种情况下,我需要将参数传递给超类。无法将逻辑放入fit()
中。我该怎么办?
答案
您可以通过从BaseEstimator继承来获取'get_params'。
class FunctionTransformer(BaseEstimator, TransformerMixin)
How to pass parameters to the customize modeltransformer class
inherit from function_transformer
您在基础中有此:
def get_params(self, deep=True):
"""
Get parameters for this estimator.
Parameters
----------
deep : bool, default=True
If True, will return the parameters for this estimator and
contained subobjects that are estimators.
Returns
更改您的代码:
trans1 = dict(functiontransformer__kw_args = [{'ind0':无},{'ind0':1}])
class Custom_Trans(preprocessing.FunctionTransformer):
def __init__(self, ind0, ind1, deep=True):
super().__init__( func=func, kw_args={ "a": ind0, "b": ind1 } )
self.ind0 = ind0
self.ind1 = ind1
self.deep = True
以上是关于继承自SciKit FunctionTransformer的主要内容,如果未能解决你的问题,请参考以下文章
为 scikit-learn 估计器子类化 XGBoostRegressor 会收到“TypeError:super() 不接受关键字参数”。