Sklearn:分组数据的交叉验证

Posted

技术标签:

【中文标题】Sklearn:分组数据的交叉验证【英文标题】:Sklearn: Cross validation for grouped data 【发布时间】:2017-03-15 04:15:42 【问题描述】:

我正在尝试对分组数据实施交叉验证方案。我希望使用 GroupKFold 方法,但我一直收到错误消息。我究竟做错了什么? 代码(与我使用的代码略有不同——我有不同的数据,所以我有一个更大的 n_splits,但其他一切都是一样的)

from sklearn import metrics
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import GroupKFold
from sklearn.grid_search import GridSearchCV
from xgboost import XGBRegressor
#generate data
x=np.array([0,1,2,3,4,5,6,7,8,9,10,11,12,13])
y= np.array([1,2,3,4,5,6,7,1,2,3,4,5,6,7])
group=np.array([1,0,1,1,2,2,2,1,1,1,2,0,0,2)]
#grid search
gkf = GroupKFold( n_splits=3).split(x,y,group)
subsample = np.arange(0.3,0.5,0.1)
param_grid = dict( subsample=subsample)
rgr_xgb = XGBRegressor(n_estimators=50)
grid_search = GridSearchCV(rgr_xgb, param_grid, cv=gkf, n_jobs=-1)
result = grid_search.fit(x, y)

错误:

Traceback (most recent call last):

File "<ipython-input-143-11d785056a08>", line 8, in <module>
result = grid_search.fit(x, y)

File "/home/student/anaconda/lib/python3.5/site-packages/sklearn/grid_search.py", line 813, in fit
return self._fit(X, y, ParameterGrid(self.param_grid))

 File "/home/student/anaconda/lib/python3.5/site-packages/sklearn/grid_search.py", line 566, in _fit
n_folds = len(cv)

TypeError: object of type 'generator' has no len()

换行

gkf = GroupKFold( n_splits=3).split(x,y,group)

gkf = GroupKFold( n_splits=3)

也不起作用。那么错误信息是:

'GroupKFold' object is not iterable

【问题讨论】:

你有什么版本的sklearnGridSearchCVcv 参数通常应该带一个生成器。 【参考方案1】:

GroupKFoldsplit 函数产生训练和测试指标一次一对。您应该在拆分值上调用 list 以将它们全部放在一个列表中,以便计算长度:

gkf = list(GroupKFold( n_splits=3).split(x,y,group))

【讨论】:

它似乎有效。但是,我尝试了gkf = list(GroupKFold( n_splits=3).split(x,y,group))gkf = list(GroupKFold( n_splits=3).split(x[:-100],y[:-100],group[:-100)),并且对于这两种情况,都使用grid_search.fit(x, y) 对其进行了训练。它们都运行平稳,结果几乎相同,而我预计第二个会失败(因为它在 gkf 上的元素比在 fit 上的元素少)。如何检查其行为? 还尝试了gkf = list(GroupKFold( n_splits=3).split(x,y,group))grid_search.fit(x[:100], y[:100]),它确实引发了一个奇怪的错误IndexError: index 100 is out of bounds for size 100

以上是关于Sklearn:分组数据的交叉验证的主要内容,如果未能解决你的问题,请参考以下文章

交叉验证——Cross-validation

sklearn:用户定义的时间序列数据交叉验证

sklearn交叉验证-老鱼学sklearn

使用 sklearn 进行交叉验证的高级特征提取

sklearn:文本分类交叉验证中的向量化

sklearn中的交叉验证+决策树