scikit-learn RandomForestClassifier 产生“意外”结果
Posted
技术标签:
【中文标题】scikit-learn RandomForestClassifier 产生“意外”结果【英文标题】:scikit-learn RandomForestClassifier produces 'unexpected' results 【发布时间】:2012-09-03 17:58:49 【问题描述】:我正在尝试将 sk-learn 的 RandomForestClassifier 用于二进制分类任务(正面和负面示例)。我的训练数据包含 1.177.245 个示例,具有 40 个特征,采用 SVM-light 格式(稀疏向量),我使用 sklearn.dataset 的 load_svmlight_file 加载。它产生一个“特征值”(1.177.245 * 40)的稀疏矩阵和一个“目标类”数组(1和0,其中1.177.245)。我不知道这是否令人担忧,但训练数据有 3552 个正数,其余均为负数。
由于 sk-learn 的 RFC 不接受稀疏矩阵,我使用 .toarray() 将稀疏矩阵转换为密集数组(如果我说的没错?很多 0 表示缺少的特征)。我在转换为数组之前和之后打印矩阵,这似乎一切正常。
当我启动分类器并开始将其拟合到数据时,需要很长时间:
[Parallel(n_jobs=40)]: Done 1 out of 40 | elapsed: 24.7min remaining: 963.3min
[Parallel(n_jobs=40)]: Done 40 out of 40 | elapsed: 27.2min finished
(输出对吗?那 963 分钟大约需要 2 个半......)
然后我使用 joblib.dump 转储它。 当我重新加载它时:
RandomForestClassifier: RandomForestClassifier(bootstrap=True, compute_importances=True,
criterion=gini, max_depth=None, max_features=auto,
min_density=0.1, min_samples_leaf=1, min_samples_split=1,
n_estimators=1500, n_jobs=40, oob_score=False,
random_state=<mtrand.RandomState object at 0x2b2d076fa300>,
verbose=1)
并在真实的训练数据上进行测试(由 750.709 个示例组成,与训练数据的格式完全相同)我得到“意外”的结果。准确地说;测试数据中只有一个示例被归类为真。当我对一半的初始训练数据进行训练并在另一半进行测试时,我根本没有得到任何阳性结果。
现在我没有理由相信正在发生的事情有什么问题,只是我得到了奇怪的结果,而且我认为这一切都完成得非常快。可能无法进行比较,但使用 rt-rank(也有 1500 次迭代,但有一半的核心)在相同数据上训练 RFClassifier 需要 12 多个小时......
谁能告诉我,我是否有任何理由相信某些事情没有按应有的方式工作?可能是训练数据中正负的比率吗?干杯。
【问题讨论】:
实际上,我在写这篇文章的时候只是查看了正负的确切比例,这对我来说似乎很合理。也许我的特征不足以区分大量的负面和少数正面? 虽然 handling unbalanced datasets in RF classifiers 有一些技术,但我认为它们中的任何一个都没有在 scikit-learn 中实现。 【参考方案1】:确实,这个数据集非常不平衡。我建议您对负样本进行二次抽样(例如随机选择n_positive_samples
)或对正样本进行过采样(后者更昂贵,但可能会产生更好的模型)。
您还确定您的所有特征都是数字特征(较大的值意味着现实生活中的某些东西)?如果其中一些是分类整数标记,则应将这些特征分解为 k 之一的布尔编码,而不是因为随机森林的 scikit-learn 实现无法直接处理分类数据。
【讨论】:
确实,极端不平衡似乎确实是 0-bias 的原因。我通过对负样本进行下采样并复制正样本进行了快速测试,并预测了更多的正样本。 @ogrisel 是否可以在 scikit-learn 中进行这种下采样?我似乎在任何地方都找不到 n_positive_samples 参数。 scikit-learn 中还没有内置的重采样器,但你可以使用标准的 numpy 花式索引。以上是关于scikit-learn RandomForestClassifier 产生“意外”结果的主要内容,如果未能解决你的问题,请参考以下文章
如何从 python 输出 RandomForest 分类器?