Scikit-Learn 中的分层标记 K 折交叉验证

Posted

技术标签:

【中文标题】Scikit-Learn 中的分层标记 K 折交叉验证【英文标题】:Stratified Labeled K-Fold Cross-Validation In Scikit-Learn 【发布时间】:2017-01-11 11:42:07 【问题描述】:

我正在尝试将数据集的实例分类为两个类之一,a 或 b。 B 是少数类,仅占数据集的 8%。所有实例都分配有一个 id,指示哪个主题生成了数据。因为每个主题生成的多个实例 ID 在数据集中经常重复。

下表只是一个例子,真实的表有大约 100,000 个实例。每个主题 ID 在表中有大约 100 个实例。正如您在下面的“larry”中看到的那样,每个主题都与一个课程相关联。

    * field  * field  *   id   *  class  
*******************************************
 0  *   _    *   _    *  bob   *    a
 1  *   _    *   _    *  susan *    a
 2  *   _    *   _    *  susan *    a
 3  *   _    *   _    *  bob   *    a
 4  *   _    *   _    *  larry *    b
 5  *   _    *   _    *  greg  *    a
 6  *   _    *   _    *  larry *    b
 7  *   _    *   _    *  bob   *    a
 8  *   _    *   _    *  susan *    a
 9  *   _    *   _    *  susan *    a
 10 *   _    *   _    *  bob   *    a
 11 *   _    *   _    *  greg  *    a
 ...   ...      ...      ...       ...

我想使用交叉验证来调整模型,并且必须对数据集进行分层,以便每个折叠都包含少数类 b 的一些示例。问题是我有第二个约束,相同的 id 决不能出现在两个不同的折叠中,因为这会泄漏有关主题的信息。

我正在使用 python 的 scikit-learn 库。我需要一种结合了 LabelKFold 和 StratifiedKFold 的方法,它可以确保标签(id)不会在折叠之间拆分,它可以确保每个折叠都具有相似的类比率。如何使用 scikit-learn 完成上述任务?如果无法在 sklearn 中拆分两个约束,我该如何手动或使用其他 python 库有效地拆分数据集?

【问题讨论】:

【参考方案1】:

以下在索引方面有点棘手(如果您使用 Pandas 之类的东西会有所帮助),但概念上很简单。

假设您创建了一个虚拟数据集,其中自变量只有 idclass。此外,在此数据集中,删除重复的 id 条目。

对于您的交叉验证,请在虚拟数据集上运行分层交叉验证。每次迭代:

    找出哪些ids 被选中用于训练和测试

    回到原始数据集,根据需要将属于id 的所有实例插入到训练和测试集中。

之所以有效,是因为:

    正如您所说,每个id 都与一个标签相关联。

    由于我们运行分层 CV,因此每个类都按比例表示。

    由于每个id 仅出现在训练集或测试集中(但不会同时出现),因此它也被标记了。

【讨论】:

这是完美的,谢谢。我很惊讶 scikit-learn 不支持如此常见的 CV 问题。 @ChrisF。你是怎么在你的代码中做到这一点的?我想做同样的事情,我使用过采样来复制少数类实例。我已经应用了 Label K-Fold 交叉验证迭代器,但我想将它与 Stratified K-Fold 结合使用以利用重复条目。 具体实现取决于您的数据集,但对我来说,我只是按照上面的说明一步一步地进行操作。我首先消除了所有重复的 id,然后运行简单的 sklearn 的 Stratified K-Fold 按类分层。然后我简单地保留了生成的折叠,但将每个 id 的其余数据添加到每个折叠中。这对我有用的唯一方法是因为我所有的 id 都有大约相同数量的实例,所以我仍然得到了与我正在寻找的相同百分比拆分以及相同的班级余额等。这也对我有用,因为每个 id 只有一个跨所有实例的类

以上是关于Scikit-Learn 中的分层标记 K 折交叉验证的主要内容,如果未能解决你的问题,请参考以下文章

一种热编码标签和分层 K 折交叉验证

如何在 scikit-learn 中使用 k 折交叉验证来获得每折的精确召回?

如何计算分层 K 折交叉验证的不平衡数据集的误报率?

Scikit-Learn:在交叉验证期间避免数据泄漏

R中的分层k倍交叉验证

如果我们在管道中包含转换器,来自 scikit-learn 的“cross_val_score”和“GridsearchCV”的 k 折交叉验证分数是不是存在偏差?