如何为 GridSearchCV 提供交叉验证的索引列表?

Posted

技术标签:

【中文标题】如何为 GridSearchCV 提供交叉验证的索引列表?【英文标题】:How to give GridSearchCV a list of indicies for cross-validation? 【发布时间】:2018-10-05 05:02:19 【问题描述】:

我正在尝试对一个非常具体的数据集使用自定义交叉验证集,并使用 BayesSearchCV 使用 scikit-optimize。我已经能够使用GridSearchCV 复制scikit-learn 的错误。

直接来自documentation:

cv : int,交叉验证生成器或可迭代的,可选的

确定交叉验证拆分策略。可能的输入 简历是:

None,使用默认的3折交叉验证,整数,指定 (分层)KFold 中的折叠数,用作 交叉验证生成器。一个可迭代的产量火车,测试拆分。 对于整数/无输入,如果估计器是分类器并且 y 是 使用二元或多类,StratifiedKFold。在所有其他 情况下,使用 KFold。

请参阅用户指南,了解可以使用的各种交叉验证策略 在这里使用。

我不能在我的特定数据集中使用cv=10。这只是为了说明错误。

我想使用文档中所述的交叉验证训练-测试拆分列表列表。如何正确格式化我的交叉验证列表?

# Generate data
def iris_data(noise=None, palette="hls", desat=1):
    # Iris dataset
    X = pd.DataFrame(load_iris().data,
                     index = [*map(lambda x:f"iris_x", range(150))],
                     columns = [*map(lambda x: x.split(" (cm)")[0].replace(" ","_"), load_iris().feature_names)])

    y = pd.Series(load_iris().target,
                           index = X.index,
                           name = "Species")
    cmap = map_colors(y, mode=1, palette=palette, desat=desat)#y.map(lambda x:0:"red",1:"green",2:"blue"[x])

    if noise is not None:
        X_noise = pd.DataFrame(
            np.random.RandomState(0).normal(size=(X.shape[0], noise)),
            index=X_iris.index,
            columns=[*map(lambda x:f"noise_x", range(noise))]
        )
        X = pd.concat([X, X_noise], axis=1)
    return (X, y, cmap)

X, y, c = iris_data(noise=50)

# Get cross-validations
cv = list()
for i in range(10):
    idx_tr = np.random.choice(np.arange(X.shape[0]),size=100, replace=False)
    idx_te = set(range(X.shape[0])) - set(idx_tr)
    tr_te_splits = [idx_tr.tolist(), list(idx_te)]
    cv.append(tr_te_splits)

# Get hyperparameter searchspace
search_spaces = 
    "n_estimators": [1,10,50],
    "criterion": ["gini", "entropy"],
    "max_features": ["sqrt", "log2", None],
    "min_samples_leaf": [1,2,3,5,8,13],


opt = GridSearchCV(RandomForestClassifier(random_state=0), search_spaces, scoring="accuracy", n_jobs=1, cv=cv)
opt.fit(X,y)

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-26-d1117d10dfa6> in <module>()
     59 
     60 opt = GridSearchCV(RandomForestClassifier(random_state=0), search_spaces, scoring="accuracy", n_jobs=1, cv=cv)
---> 61 opt.fit(X,y)

~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
    637                                   error_score=self.error_score)
    638           for parameters, (train, test) in product(candidate_params,
--> 639                                                    cv.split(X, y, groups)))
    640 
    641         # if one choose to see train score, "out" will contain train score info

~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/externals/joblib/parallel.py in __call__(self, iterable)
    777             # was dispatched. In particular this covers the edge
    778             # case of Parallel used with an exhausted iterator.
--> 779             while self.dispatch_one_batch(iterator):
    780                 self._iterating = True
    781             else:

~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/externals/joblib/parallel.py in dispatch_one_batch(self, iterator)
    623                 return False
    624             else:
--> 625                 self._dispatch(tasks)
    626                 return True
    627 

~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/externals/joblib/parallel.py in _dispatch(self, batch)
    586         dispatch_timestamp = time.time()
    587         cb = BatchCompletionCallBack(dispatch_timestamp, len(batch), self)
--> 588         job = self._backend.apply_async(batch, callback=cb)
    589         self._jobs.append(job)
    590 

~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/externals/joblib/_parallel_backends.py in apply_async(self, func, callback)
    109     def apply_async(self, func, callback=None):
    110         """Schedule a func to be run"""
--> 111         result = ImmediateResult(func)
    112         if callback:
    113             callback(result)

~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/externals/joblib/_parallel_backends.py in __init__(self, batch)
    330         # Don't delay the application, to avoid keeping the input
    331         # arguments in memory
--> 332         self.results = batch()
    333 
    334     def get(self):

~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/externals/joblib/parallel.py in __call__(self)
    129 
    130     def __call__(self):
--> 131         return [func(*args, **kwargs) for func, args, kwargs in self.items]
    132 
    133     def __len__(self):

~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/externals/joblib/parallel.py in <listcomp>(.0)
    129 
    130     def __call__(self):
--> 131         return [func(*args, **kwargs) for func, args, kwargs in self.items]
    132 
    133     def __len__(self):

~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/model_selection/_validation.py in _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, return_n_test_samples, return_times, error_score)
    446     start_time = time.time()
    447 
--> 448     X_train, y_train = _safe_split(estimator, X, y, train)
    449     X_test, y_test = _safe_split(estimator, X, y, test, train)
    450 

~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/utils/metaestimators.py in _safe_split(estimator, X, y, indices, train_indices)
    198             X_subset = X[np.ix_(indices, train_indices)]
    199     else:
--> 200         X_subset = safe_indexing(X, indices)
    201 
    202     if y is not None:

~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/utils/__init__.py in safe_indexing(X, indices)
    144     if hasattr(X, "iloc"):
    145         # Work-around for indexing with read-only indices in pandas
--> 146         indices = indices if indices.flags.writeable else indices.copy()
    147         # Pandas Dataframes and Series
    148         try:

AttributeError: 'list' object has no attribute 'flags'

)

【问题讨论】:

【参考方案1】:

由于输入对象Xypandas,我相信它们需要命名索引。如果我通过.values 方法将它们转换为numpy,那么它就可以工作。如果你这样做,你只需要确保订单是正确的。

【讨论】:

以上是关于如何为 GridSearchCV 提供交叉验证的索引列表?的主要内容,如果未能解决你的问题,请参考以下文章

如何仅使用 GridSearchCV 进行简单的交叉验证

使用 Scikit-Learn GridSearchCV 与 PredefinedSplit 进行交叉验证 - 可疑的交叉验证结果

使用 Imblearn 管道和 GridSearchCV 进行交叉验证

如何在 python 中使用交叉验证执行 GridSearchCV

使用 Keras 和 sklearn GridSearchCV 交叉验证提前停止

交叉验证是如何执行的以及 GridSearchCV() 具体是如何执行的?