keras/scikit-learn: using fit_generator() with cross validation
【Posted】: 2017-04-12 18:16:31

【Question】:

Is it possible to use Keras's scikit-learn API together with the fit_generator() method? Or is there another way to yield batches for training? I'm using SciPy sparse matrices that have to be converted to NumPy arrays before being fed to Keras, but I can't convert them all at once because of the high memory consumption. Here is my function for yielding batches:
def batch_generator(X, y, batch_size):
    n_splits = len(X) // (batch_size - 1)
    X = np.array_split(X, n_splits)
    y = np.array_split(y, n_splits)
    while True:
        for i in range(len(X)):
            X_batch = []
            y_batch = []
            for ii in range(len(X[i])):
                X_batch.append(X[i][ii].toarray().astype(np.int8))  # conversion sparse matrix -> np.array
                y_batch.append(y[i][ii])
            yield (np.array(X_batch), np.array(y_batch))
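For reference, this generator would be plugged into Keras directly (without the scikit-learn wrapper) roughly as follows. This is only a sketch with placeholder names (model, X_train, y_train), not code from the question, and it uses the Keras 1.x argument names (samples_per_epoch, nb_epoch); Keras 2 renamed them to steps_per_epoch and epochs:

# sketch: feeding the batch generator straight into fit_generator()
# `model` is assumed to be a compiled Keras model, `X_train` a sequence of sparse rows
model.fit_generator(
    batch_generator(X_train, y_train, batch_size=32),
    samples_per_epoch=len(X_train),
    nb_epoch=10)

The question is whether this can be routed through the scikit-learn wrapper instead.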
And here is the example code with cross validation:
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn import datasets
from keras.models import Sequential
from keras.layers import Activation, Dense
from keras.wrappers.scikit_learn import KerasClassifier
import numpy as np

def build_model(n_hidden=32):
    model = Sequential([
        Dense(n_hidden, input_dim=4),
        Activation("relu"),
        Dense(n_hidden),
        Activation("relu"),
        Dense(3),
        Activation("sigmoid")
    ])
    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model

iris = datasets.load_iris()
X = iris["data"]
y = iris["target"].flatten()

param_grid = {
    "n_hidden": np.array([4, 8, 16]),
    "nb_epoch": np.array(range(50, 61, 5))
}

model = KerasClassifier(build_fn=build_model, verbose=0)
skf = StratifiedKFold(n_splits=5).split(X, y)  # this yields (train_indices, test_indices)
grid = GridSearchCV(model, param_grid, cv=skf, verbose=2, n_jobs=4)
grid.fit(X, y)

print(grid.best_score_)
print(grid.cv_results_["params"][grid.best_index_])
To explain it in more detail: this builds a model for every possible combination of hyperparameters in param_grid. Each model is then trained and tested on the train/test splits (folds) provided by StratifiedKFold, and the final score for a given model is the mean score over all folds.
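(Roughly, the search above amounts to the hand-written loop below. This is only an illustration of the mean-over-folds scoring, written against the same build_model and KerasClassifier pieces, not how GridSearchCV is implemented internally:)

from itertools import product

# illustrative only: an explicit version of what the grid search does conceptually
best_score, best_params = -np.inf, None
for n_hidden, nb_epoch in product([4, 8, 16], range(50, 61, 5)):
    fold_scores = []
    for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y):
        clf = KerasClassifier(build_fn=build_model, n_hidden=n_hidden,
                              nb_epoch=nb_epoch, verbose=0)
        clf.fit(X[train_idx], y[train_idx])
        fold_scores.append(clf.score(X[test_idx], y[test_idx]))
    if np.mean(fold_scores) > best_score:
        best_score = np.mean(fold_scores)
        best_params = {"n_hidden": n_hidden, "nb_epoch": nb_epoch}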
So, is it possible to insert some preprocessing sub-step into the code above so that the data (the sparse matrices) gets transformed before the actual fitting?

I know I could write my own cross-validation generator, but it has to yield indices, not the actual data!
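(For the record, such a custom split generator would look roughly like the sketch below; it is hypothetical code, and it only ever yields index arrays, which is exactly why it does not help with the sparse-to-dense conversion:)

# minimal sketch of a custom CV split generator: scikit-learn expects
# (train_indices, test_indices) pairs here, never the transformed data itself
def my_cv_splits(X, y, n_splits=5):
    for train_idx, test_idx in StratifiedKFold(n_splits=n_splits).split(X, y):
        yield train_idx, test_idx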
【Question comments】:

github.com/cerlymarco/keras-hypetune

【Answer 1】:

Actually, you can use sparse matrices as input to Keras with a generator. Here is the version I used in a previous project:
# imports needed for the snippet (assuming a Keras 1.x-style API, matching the
# samples_per_epoch / val_samples arguments used below)
import copy
import types

import numpy as np
from scipy.sparse import issparse
from keras.models import Sequential
from keras.utils.np_utils import to_categorical
from keras.wrappers.scikit_learn import KerasClassifier


class KerasClassifier(KerasClassifier):
    """ adds sparse matrix handling using batch generator
    """

    def fit(self, x, y, **kwargs):
        """ adds sparse matrix handling """
        if not issparse(x):
            return super().fit(x, y, **kwargs)

        ############ adapted from KerasClassifier.fit ######################
        if self.build_fn is None:
            self.model = self.__call__(**self.filter_sk_params(self.__call__))
        elif not isinstance(self.build_fn, types.FunctionType):
            self.model = self.build_fn(
                **self.filter_sk_params(self.build_fn.__call__))
        else:
            self.model = self.build_fn(**self.filter_sk_params(self.build_fn))

        loss_name = self.model.loss
        if hasattr(loss_name, '__name__'):
            loss_name = loss_name.__name__
        if loss_name == 'categorical_crossentropy' and len(y.shape) != 2:
            y = to_categorical(y)
        ### fit => fit_generator
        fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit_generator))
        fit_args.update(kwargs)
        ############################################################
        self.model.fit_generator(
            self.get_batch(x, y, self.sk_params["batch_size"]),
            samples_per_epoch=x.shape[0],
            **fit_args)
        return self

    def get_batch(self, x, y=None, batch_size=32):
        """ batch generator to enable sparse input """
        index = np.arange(x.shape[0])
        start = 0
        while True:
            if start == 0 and y is not None:
                np.random.shuffle(index)
            batch = index[start:start + batch_size]
            if y is not None:
                yield x[batch].toarray(), y[batch]
            else:
                yield x[batch].toarray()
            start += batch_size
            if start >= x.shape[0]:
                start = 0

    def predict_proba(self, x):
        """ adds sparse matrix handling """
        if not issparse(x):
            return super().predict_proba(x)

        preds = self.model.predict_generator(
            self.get_batch(x, None, self.sk_params["batch_size"]),
            val_samples=x.shape[0])
        return preds
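With that subclass in place, the grid-search code from the question can be reused almost unchanged. An untested usage sketch (X_sparse stands for the scipy sparse matrix; batch_size is passed through sk_params so that get_batch can read it; note that the predict/score paths are not patched above, so evaluating on sparse folds may need similar treatment):

# hypothetical usage: the subclass above shadows the stock KerasClassifier,
# so the rest of the GridSearchCV setup stays as in the question
model = KerasClassifier(build_fn=build_model, batch_size=32, verbose=0)
skf = StratifiedKFold(n_splits=5).split(X_sparse, y)
grid = GridSearchCV(model, param_grid, cv=skf, verbose=2)
grid.fit(X_sparse, y)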
【Comments】:
This looks good. I had also thought about modifying Keras's source code, but I wanted to avoid that. Thanks, I'll give it a try :)