scikit-learn 中的 LassoCV 如何分区数据?
Posted
技术标签:
【中文标题】scikit-learn 中的 LassoCV 如何分区数据?【英文标题】:How does LassoCV in scikit-learn partition data? 【发布时间】:2014-08-05 16:55:11 【问题描述】:我正在使用 sklearn 中的 Lasso 方法执行线性回归。
根据他们的指导以及我在其他地方看到的指导,建议不要简单地对所有训练数据进行交叉验证,而是将其拆分为更传统的训练集/验证集分区。
因此,Lasso 在训练集上进行训练,然后根据验证集的交叉验证结果调整超参数 alpha。最后,在测试集上使用接受的模型来给出一个真实的视图,哦它将如何在现实中执行。在这里分离关注点是防止过度拟合的一种预防措施。
实际问题
Lasso CV 是否符合上述协议,还是只是在同一数据和/或同一轮 CV 中以某种方式训练模型参数和超参数?
谢谢。
【问题讨论】:
【参考方案1】:如果您将sklearn.cross_validation.cross_val_score
与sklearn.linear_model.LassoCV
对象一起使用,那么您正在执行嵌套交叉验证。 cross_val_score
将根据您指定折叠的方式将您的数据分为训练集和测试集(可以使用 sklearn.cross_validation.KFold
等对象完成)。训练集将传递给LassoCV
,它本身会执行数据的另一次拆分,以便选择正确的惩罚。这似乎与您正在寻找的设置相对应。
import numpy as np
from sklearn.cross_validation import KFold, cross_val_score
from sklearn.linear_model import LassoCV
X = np.random.randn(20, 10)
y = np.random.randn(len(X))
cv_outer = KFold(len(X), n_folds=5)
lasso = LassoCV(cv=3) # cv=3 makes a KFold inner splitting with 3 folds
scores = cross_val_score(lasso, X, y, cv=cv_outer)
答案:否,LassoCV
不会为你做所有的工作,你必须结合使用它和cross_val_score
来获得你想要的东西。这同时也是实现此类对象的合理方式,因为我们也可以对仅拟合优化的超参数 LassoCV
感兴趣,而不必直接在另一组保留的数据上对其进行评估。
【讨论】:
只是为了确认:内部分裂的唯一目的是在LassoCV中选择“最好的”超参数C?如果模型不存在于this 列表中,那么进行超参数调整(比如 SVR)的推荐方法是使用 GridSearchCV 还是 RandomizedSearchCV?所以外部 CV 并没有改进模型,而只是检查如何在从未见过的数据上执行?如果使用简单的多元线性回归(无超参数),模型无法针对一般性能进行调整? 对所有这些问题持肯定态度。对于最后一个问题:调整模型的一种方法是是否包含列/特征。如果您使用sklearn.pipeline.Pipeline
,您可以预先添加一个特征选择器,例如sklearn.prepreprocessing.SelectKBest
到您的 OLS 中的管道并在 GridSearchCV
中使用此管道,后者检查 k
的不同数字。
嘿,在将LassoCV
嵌套在cross_val_score
中并在训练集上运行它之后,有没有办法检查拟合参数以在测试集上重新运行它们?
另外,如果你想使用 RMSE 评分,是否应该使用 cross_val_score
+Lasso
+GridSearchCV
而不是 cross_val_score
+LassoCV
进行嵌套交叉验证?
注意:有人试图编辑这篇文章以反映当前的sklearn
API,它使用sklearn.model_selection.cross_val_score/KFold
而不是sklearn.cross_validation.cross_val_score/KFold
,并使用n_splits
而不是n_folds
。此编辑被足够多的审阅者拒绝以做出决定,但不是错误的。我将保持原样,因为我认为它仍然有效(尽管有弃用警告),但如果有人想再次编辑,请继续。以上是关于scikit-learn 中的 LassoCV 如何分区数据?的主要内容,如果未能解决你的问题,请参考以下文章
sklearn、LassoCV() 和 ElasticCV() 坏了?
python使用lassocv生成影像组学(radiomic)模型的系数表