scikit-learn:应用任意函数作为管道的一部分
Posted
技术标签:
【中文标题】scikit-learn:应用任意函数作为管道的一部分【英文标题】:scikit-learn: applying an arbitary function as part of a pipeline 【发布时间】:2017-08-08 05:45:39 【问题描述】:我刚刚发现了 scikit-learn 的 Pipeline 功能,我发现它对于在训练我的模型之前测试预处理步骤的不同组合非常有用。
管道是实现fit
和transform
方法的对象链。现在,如果我想添加一个新的预处理步骤,我曾经编写一个继承自sklearn.base.estimator
的类。但是,我认为必须有一个更简单的方法。我真的需要将要应用的每个函数都包装在估算器类中吗?
例子:
class Categorizer(sklearn.base.BaseEstimator):
"""
Converts given columns into pandas dtype 'category'.
"""
def __init__(self, columns):
self.columns = columns
def fit(self, X, y):
return self
def transform(self, X):
for column in self.columns:
X[column] = X[column].astype("category")
return X
【问题讨论】:
【参考方案1】:我认为值得一提的是sklearn.preprocessing.FunctionTransformer(..., validate=True)
有一个validate=False
参数:
验证:
bool
,可选default=True
表示在调用
func
之前应检查输入X
数组。如果 validate 为 false,则不会进行输入验证。 如果是 为真,则X
将被转换为二维 NumPy 数组或 稀疏矩阵。如果无法进行此转换或X
包含NaN
或 无穷大,引发异常。
因此,如果您要将非数字特征传递给FunctionTransformer
,请确保您明确设置validate=False
,否则它将失败并出现以下异常:
ValueError: could not convert string to float: 'your non-numerical value'
【讨论】:
【参考方案2】:对于通用解决方案(适用于许多其他用例,不仅是转换器,还包括简单模型等),如果您有无状态函数(可以不实现 fit),例如通过这样做:
class TransformerWrapper(sklearn.base.BaseEstimator):
def __init__(self, func):
self._func = func
def fit(self, *args, **kwargs):
return self
def transform(self, X, *args, **kwargs):
return self._func(X, *args, **kwargs)
现在你可以做
@TransformerWrapper
def foo(x):
return x*2
相当于做
def foo(x):
return x*2
foo = TransformerWrapper(foo)
这就是 sklearn.preprocessing.FunctionTransformer 在幕后所做的。
我个人觉得装饰更简单,因为您可以很好地将预处理器与其余代码分开,但要遵循哪条路径取决于您。
其实你应该可以通过 sklearn 函数来装饰
from sklearn.preprocessing import FunctionTransformer
@FunctionTransformer
def foo(x):
return x*2
也是。
【讨论】:
【参考方案3】:sklearn.preprocessing.FunctionTransformer
类可用于从用户提供的函数实例化一个 scikit-learn 转换器(可用于例如管道中)。
【讨论】:
不幸的是,FunctionTransformer 似乎强制输出是一个只有数字内容的 numpy ndarray,这在管道的每个阶段都不起作用。以上是关于scikit-learn:应用任意函数作为管道的一部分的主要内容,如果未能解决你的问题,请参考以下文章
如何在 Scikit-learn 的管道中创建我们的自定义特征提取器函数并将其与 countvectorizer 一起使用
从磁盘加载包含预训练 Keras 模型的 scikit-learn 管道
如何创建一个应用 z-score 和交叉验证的 scikit-learn 管道?
scikit-learn:在管道中使用 SelectKBest 时获取选定的功能