scikit-learn StratifiedShuffleSplit KeyError with index
Asked: 2018-03-19 20:34:07
Question: Here is my pandas DataFrame lots_not_preprocessed_usd:
<class 'pandas.core.frame.DataFrame'>
Index: 78718 entries, 2017-09-12T18-38-38-076065 to 2017-10-02T07-29-40-245031
Data columns (total 20 columns):
created_year 78718 non-null float64
price 78718 non-null float64
........
decade 78718 non-null int64
dtypes: float64(8), int64(1), object(11)
memory usage: 12.6+ MB
head(1):
artist_name_normalized house created_year description exhibited_in exhibited_in_museums height images max_estimated_price min_estimated_price price provenance provenance_estate_of sale_date sale_id sale_title style title width decade
key
2017-09-12T18-38-38-076065 NaN c11 1862.0 An Album and a small Quantity of unframed Draw... NaN NaN NaN NaN 535.031166 267.515583 845.349242 NaN NaN 1998-06-21 8033 OILS, WATERCOLOURS & DRAWINGS FROM 18TH - 20TH... watercolor painting An Album and a small Quantity of unframed Draw... NaN 186
My script:
from sklearn.model_selection import StratifiedShuffleSplit
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']):
    strat_train_set = lots_not_preprocessed_usd.loc[train_index]
    strat_test_set = lots_not_preprocessed_usd.loc[test_index]
I get this error message:
KeyError Traceback (most recent call last)
<ipython-input-224-cee2389254f2> in <module>()
3 split = StratifiedShuffleSplit(n_splits=1, test_size =0.2, random_state=42)
4 for train_index, test_index in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']):
----> 5 strat_train_set = lots_not_preprocessed_usd.loc[train_index]
6 strat_test_set = lots_not_preprocessed_usd.loc[test_index]
......
KeyError: 'None of [[32199 67509 69003 ..., 44204 2809 56726]] are in the [index]'
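There seems to be a problem with my index (e.g. 2017-09-12T18-38-38-076065), but I don't understand what it is. Where is the problem?
If I use a different split, it works as expected: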
from sklearn.model_selection import train_test_split
train_set, test_set = train_test_split(lots_not_preprocessed_usd, test_size=0.2, random_state=42)
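train_test_split works because it returns slices of the DataFrame itself instead of handing back positions for you to look up. As a side note that goes beyond the original question, it also accepts a stratify argument, so a stratified 80/20 split can be sketched without the loop. The small frame below, with made-up string keys standing in for the question's timestamp-like index, is an assumption for illustration only.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Made-up frame: 100 rows, string labels, four equally sized 'decade' classes.
rng = np.random.RandomState(0)
df = pd.DataFrame(
    {"price": rng.rand(100), "decade": np.repeat([186, 187, 188, 189], 25)},
    index=["key-%03d" % i for i in range(100)],
)

# stratify= keeps the 'decade' proportions the same in both parts.
train_set, test_set = train_test_split(
    df, test_size=0.2, random_state=42, stratify=df["decade"]
)

print(len(train_set), len(test_set))  # 80 20
print(test_set["decade"].value_counts().sort_index().tolist())  # [5, 5, 5, 5]
```

Both returned objects are DataFrames that keep the original string labels, so no .loc or .iloc lookup is needed afterwards.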
Comments on the question:
Add lots_not_preprocessed_usd.head() for more clarity.
Answer 1:
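When you use .loc, the row indexer must contain labels from the DataFrame's own index, so when you want to select rows by plain integer positions, use .iloc instead of .loc. In the for loop, train_index and test_index are not labels from your index (such as 2017-09-12T18-38-38-076065): split.split(X, y) yields NumPy arrays of randomly shuffled integer row positions (0 to n_samples - 1), and .loc fails because none of those integers exist as labels in your index.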
...
for train_index, test_index in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']):
    strat_train_set = lots_not_preprocessed_usd.iloc[train_index]
    strat_test_set = lots_not_preprocessed_usd.iloc[test_index]
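The failure and the fix can be reproduced on a small frame. The sketch below assumes a made-up 20-row DataFrame named df whose index holds strings, like the question's keys; it checks that the arrays yielded by split.split are integers, that .loc rejects them with a KeyError, and that .iloc accepts them.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Made-up frame whose index holds strings, like the question's keys.
df = pd.DataFrame(
    {"decade": np.repeat([186, 187], 10)},
    index=["2017-09-12T18-38-%02d" % i for i in range(20)],
)

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_index, test_index = next(split.split(df, df["decade"]))

# The yielded arrays hold integer positions, not index labels.
print(np.issubdtype(train_index.dtype, np.integer))  # True

try:
    df.loc[train_index]  # treats the integers as labels, none of which exist
    loc_failed = False
except KeyError:
    loc_failed = True
print(loc_failed)  # True

strat_train_set = df.iloc[train_index]  # positional lookup works
print(len(strat_train_set))  # 16
```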
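Example:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
lots_not_preprocessed_usd = pd.DataFrame({'some': np.random.randint(5, 10, 100), 'decade': np.random.randint(5, 10, 100)}, index=pd.date_range('5-10-15', periods=100))
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']):
    strat_train_set = lots_not_preprocessed_usd.iloc[train_index]  # positions, not labels
    strat_test_set = lots_not_preprocessed_usd.iloc[test_index]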
Sample output:
strat_train_set.head()
            decade  some
2015-08-02       6     7
2015-06-14       7     6
2015-08-14       7     9
2015-06-25       9     5
2015-05-15       7     9
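If you would rather keep .loc, one alternative (not what the answer proposes) is to translate the positions into labels first with df.index[...]. The sketch below uses a made-up, seeded frame with a DatetimeIndex similar to the sample above; it checks that both routes select the same rows and that stratification preserved the 'decade' proportions.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Made-up frame with a DatetimeIndex and four equally sized 'decade' classes.
rng = np.random.RandomState(0)
df = pd.DataFrame(
    {"some": rng.randint(5, 10, 100), "decade": np.repeat([5, 6, 7, 8], 25)},
    index=pd.date_range("2015-05-10", periods=100),
)

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(df, df["decade"]):
    by_position = df.iloc[train_index]
    # Convert positions to labels first; .loc then finds them.
    by_label = df.loc[df.index[train_index]]

print(by_position.equals(by_label))  # True

# Class shares in the training part match the full data (0.25 each here).
print(by_position["decade"].value_counts(normalize=True).sort_index().tolist())
# [0.25, 0.25, 0.25, 0.25]
```

The .iloc route is shorter; the df.index[...] route is mainly useful when the selected labels are needed elsewhere.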