sklearn 的 train_test_split 中的“Stratify”参数无法正常工作?
Posted
技术标签:
【中文标题】sklearn 的 train_test_split 中的“Stratify”参数无法正常工作?【英文标题】:"Stratify" parameter from sklearn's train_test_split not working correctly? 【发布时间】:2017-02-12 21:08:46 【问题描述】:scikit-learn 的train_test_split()
函数中的stratify
参数有问题。这是一个虚拟示例,在我的数据中随机出现相同的问题:
from sklearn.model_selection import train_test_split
a = [1, 0, 0, 0, 0, 0, 0, 1]
train_test_split(a, stratify=a, random_state=42)
返回:
[[1, 0, 0, 0, 0, 1], [0, 0]]
它不应该在测试子集中也选择一个“1”吗?从我期望train_test_split()
和stratify
的工作方式来看,它应该返回如下内容:
[[1, 0, 0, 0, 0, 0], [0, 1]]
random_state
的某些值会发生这种情况,而其他值则可以正常工作;但是每次我必须分析数据时,我都无法搜索它的“正确”值。
我有 python 2.7 和 scikit-learn 0.18。
【问题讨论】:
如果您尝试使用stratify=np.unique(a)
会怎样?
很遗憾,它不起作用,因为传递给stratify
的列表必须与要拆分的列表长度相同。
文档中没有任何地方声明即使在很小的子集中也会有所有类。如果您将唯一的 1 添加到列表中,那么您将在测试拆分中获得 1 类。我认为它应该与您的火车拆分中的第 1 类部分相同。例如,如果您删除“分层”,那么您将得到列表的尾部,而不是带有混洗类的列表。
【参考方案1】:
这个问题是 8 个月前提出的,但我想答案可能对未来的读者仍有帮助。
当使用stratify
参数时,train_test_split
实际上是依赖StratifiedShuffleSplit
函数进行拆分。正如您在documentation 中看到的那样,StratifiedShuffleSplit
确实旨在通过保留每个类的样本百分比来进行拆分,正如您所期望的那样。
问题是,在您的示例中,25%(8 个样本中的 2 个)是 1s,但样本量不足以让您在测试集上看到这一比例。您有两种选择:
A. 使用选项test_size
(默认为 0.25)增加测试集的大小,例如 0.5。在这种情况下,一半的样本将成为您的测试集,您会看到其中 25%(即 4 分之一)是 1。
>>> a = [1, 0, 0, 0, 0, 0, 0, 1]
>>> train_test_split(a, stratify=a, random_state=42, test_size=0.5)
[[1, 0, 0, 0], [0, 0, 1, 0]]
B. 将test_size
保留为其默认值并增加您的集合a
的大小,使其25% 的样本至少包含4 个元素。 16 个或更多样本的 a
将为您做到这一点。
>>> a = [1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1]
>>> train_test_split(a, stratify=a, random_state=42)
[[0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0]]
希望对您有所帮助。
【讨论】:
感谢您的回答!很有帮助!以上是关于sklearn 的 train_test_split 中的“Stratify”参数无法正常工作?的主要内容,如果未能解决你的问题,请参考以下文章