train_test_split 函数是不是保持类之间的平衡

Posted

技术标签:

【中文标题】train_test_split 函数是不是保持类之间的平衡【英文标题】:Does the train_test_split function keep the balance between classestrain_test_split 函数是否保持类之间的平衡 【发布时间】:2019-07-03 04:43:54 【问题描述】:

我有一个问题,我一直在寻找答案,但我找不到答案。

如果我有一个使用三个或更多类标记的数据集,其中每个类代表 33% 的数据。当我拆分数据时,训练/验证/测试集是否在类之间保持相同的平衡?

如果没有,有没有办法保持平衡?

提前致谢。

【问题讨论】:

Stratified Train/Test-split in scikit-learn的可能重复 【参考方案1】:

找到了!

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

【讨论】:

这是做什么的? 它对训练/测试集中的数据进行分层并保持类的数量平衡,例如,如果您有 100 个 class1 和 100 个 class2,当您以 0.2 测试大小拆分时,您将得到一个训练集有 80 个 class1 和 80 个 class2 以及一个包含 20 个 class1 和 20 个 class2 的测试集

以上是关于train_test_split 函数是不是保持类之间的平衡的主要内容,如果未能解决你的问题,请参考以下文章

我们如何将显式测试数据和训练数据提供给 SVM,而不是使用 train_test_split 函数?

用 numpy 编写一个 train_test_split 函数

train_test_split()函数

sklearn的train_test_split函数

为什么在train_test_split的两个数组中都包含目标类?

TypeError:train_test_split() 只有当我在函数中写入参数'test_size'时才获得多个值