scikit-learn:随机森林 class_weight 和 sample_weight 参数

Posted

技术标签:

【中文标题】scikit-learn:随机森林 class_weight 和 sample_weight 参数【英文标题】:scikit-learn: Random forest class_weight and sample_weight parameters 【发布时间】:2015-08-28 14:08:07 【问题描述】:

我有一个类不平衡问题,并且一直在使用 scikit-learn (>= 0.16) 中的实现来试验加权随机森林。

我注意到该实现在树构造函数中采用 class_weight 参数,在 fit 方法中采用 sample_weight 参数来帮助解决类不平衡问题。尽管这两个似乎相乘以决定最终权重。

我无法理解以下内容:

在树构建/训练/预测的哪些阶段使用这些权重?我看过一些关于加权树的论文,但我不确定 scikit 实现了什么。 class_weight 和 sample_weight 到底有什么区别?

【问题讨论】:

【参考方案1】:

RandomForests 是建立在树上的,树上有很好的记录。检查 Trees 如何使用样本权重:

User guide on decision trees - 准确说明所使用的算法 Decision tree API - 解释树如何使用 sample_weight(对于随机森林,正如您所确定的,它是 class_weight 和 sample_weight 的乘积)。

至于class_weightsample_weight 之间的区别:很大程度上可以通过它们的数据类型的性质来确定。 sample_weight 是长度为 n_samples 的一维数组,为用于训练的每个示例分配显式权重。 class_weight 要么是每个类的字典到该类的统一权重(例如,1:.9, 2:.5, 3:.01),要么是告诉 sklearn 如何自动确定该字典的字符串。

因此,给定示例的训练权重是其明确命名为 sample_weight(或 1,如果未提供 sample_weight)的乘积,它是 class_weight(或 1,如果 class_weight 是未提供)。

【讨论】:

您对 DT 作为基本分类器是正确的。我还对在训练期间(例如,决定决策节点的杂质等)和预测期间如何使用这些权重感兴趣。 fit()方法下查看sample_weight的文档:Decision tree API 虽然不是一清二楚,但我想等我消化一下链接里的材料,一切都会更清楚了:)

以上是关于scikit-learn:随机森林 class_weight 和 sample_weight 参数的主要内容,如果未能解决你的问题,请参考以下文章

使用 scikit-learn 并行生成随机森林

在 scikit-learn 中平均多个随机森林模型

scikit-learn 中的随机森林解释

scikit-learn 随机森林过多的内存使用

随机森林中的引导数(scikit-learn)

scikit-learn 随机森林永远不会完成训练,冻结