在 cross_val_predict (sklearn) 中使用 StratifiedShuffleSplit
Posted
技术标签:
【中文标题】在 cross_val_predict (sklearn) 中使用 StratifiedShuffleSplit【英文标题】:use StratifiedShuffleSplit in cross_val_predict (sklearn) 【发布时间】:2020-12-07 20:54:14 【问题描述】:我正在尝试使用有监督的机器学习来根据作物(例如土豆)各自的长度和宽度测量值来预测其重量。在拟合特定模型(例如线性回归)之前,我想根据我的数据集中特定作物品种的频率对我的特征进行分层样本。例如,如果我将数据分成 5 个分区(即我使用交叉验证),并且品种 1 占我观察的 50%,则每个分区训练集中 50% 的观察应该对应于品种 1。这是我使用 sklearn(版本 0.23)在 Python 中尝试过的代码:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import cross_val_predict
from sklearn.linear_model import LinearRegression
# build pd.DataFrame
varieties = np.concatenate([np.repeat("variety1", 10),
np.repeat("variety2", 30),
np.repeat("variety3", 60)])
columns = "variety": varieties,
"length": np.random.randint(30, 70, size=100),
"width": np.random.randint(40, 50, size=100),
"weight": np.random.random(100)*100 + 50
df = pd.DataFrame(columns)
# stratified sampling
kf = StratifiedShuffleSplit(n_splits=5, test_size=0.2)
# fit model based on a cv splitter
lm = LinearRegression()
X = df.loc[:,"length":"width"]
y = df["weight"]
y_pred = cross_val_predict(lm, X, y, cv=kf.split(X, df["variety"]))
但是,当我运行此代码时,我收到以下错误:
ValueError: cross_val_predict only works for partitions
这对我来说有点令人惊讶,因为根据documentation of sklearn 我们可以在cross_val_predict 的cv 参数中使用拆分器。我知道我可以使用 for 循环来完成我想要的:
kf = StratifiedShuffleSplit(n_splits=5, test_size=0.2)
X = df.loc[:,"length":"width"]
y = df["weight"]
y_pred = np.zeros(y.size)
for train_idx, test_idx in kf.split(X, df["variety"]):
#get subsets of variables from CV
X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
#fit model
lm.fit(X_train, y_train)
pred_vals = lm.predict(X_test)
#store predicted values
y_pred[test_idx] = pred_vals
但是,我更喜欢使用 cross_val_predict 来使代码更紧凑。有可能吗?
【问题讨论】:
【参考方案1】:尝试使用StratifiedKFold
而不是StratifiedShuffleSplit
。
不同之处在于StratifiedKFold只是shuffle和split一次,因此测试集不会重叠,而StratifiedShuffleSplit每次分裂前都会shuffle,并且拆分n_splits次,测试集可以重叠,并且某些数据分区永远不会成为其中的一部分测试数据集,这意味着没有对它们的预测。
您可以在Catbuilts's explanation阅读更多内容
【讨论】:
以上是关于在 cross_val_predict (sklearn) 中使用 StratifiedShuffleSplit的主要内容,如果未能解决你的问题,请参考以下文章
为啥 cross_val_predict 比适合 KNeighborsClassifier 慢得多?
为啥 cross_val_predict 不适合测量泛化误差?
在 cross_val_predict (sklearn) 中使用 StratifiedShuffleSplit
为啥当 cv=5 时 cross_val_predict 只返回单个预测数组