随机森林中的 class_weight 超参数改变了混淆矩阵中的样本数量
Posted
技术标签:
【中文标题】随机森林中的 class_weight 超参数改变了混淆矩阵中的样本数量【英文标题】:class_weight hyperparameter in Random Forest change the amounts of samples in confusion matrix 【发布时间】:2018-04-15 05:04:24 【问题描述】:我目前正在研究一个包含 24,000 个样本的随机森林分类模型,其中 20,000 个属于 class 0
,其中 4,000 个属于 class 1
。我制作了一个train_test_split
,其中 test_set 是整个数据集的0.2
(test_set
中的大约 4,800 个样本)。由于我正在处理不平衡的数据,因此我查看了旨在解决此问题的超参数class_weight
。
我在设置class_weight='balanced'
时遇到的问题,看看训练集的confusion_matrix
,我得到了类似的结果:
array([[13209, 747],
[ 2776, 2468]])
如您所见,下方数组对应False Negative = 2776
,后跟True Positive = 2468
,而上方数组对应True Negative = 13209
,后跟False Positive = 747
。问题是样本量属于class 1
,根据confusion_matrix
是2,776 (False Negative) + 2,468 (True Positive)
,加起来5,244 samples
属于class 1
。这没有任何意义,因为整个数据集仅包含 4,000 个属于 class 1
的样本,其中只有 3,200 个位于 train_set
中。看起来confusion_matrix
返回了矩阵的Transposed
版本,因为在training_set
中属于class 1
的样本的实际数量应该在train_set
中总计3,200 个样本,在test_set
中总计800 个样本。一般来说,正确的数字应该是 747 + 2468,总和为 3,215,这是属于class 1
的正确样本数量。
有人可以解释一下我使用class_weight
时会发生什么吗? confusion_matrix
返回矩阵的transposed
版本是真的吗?我看错了吗?
我尝试寻找答案并访问了几个在某种程度上相似的问题,但没有一个真正涵盖了这个问题。
这些是我查看的一些来源:
scikit-learn: Random forest class_weight and sample_weight parameters
How to tune parameters in Random Forest, using Scikit Learn?
https://datascience.stackexchange.com/questions/11564/how-does-class-weights-work-in-randomforestclassifier
https://stats.stackexchange.com/questions/244630/difference-between-sample-weight-and-class-weight-randomforest-classifier
using sample_weight and class_weight in imbalanced dataset with RandomForest Classifier
任何帮助将不胜感激,谢谢。
【问题讨论】:
【参考方案1】:从docs复制玩具示例:
from sklearn.metrics import confusion_matrix
y_true = [0, 1, 0, 1]
y_pred = [1, 1, 1, 0]
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
(tn, fp, fn, tp)
# (0, 2, 1, 1)
所以,您提供的混淆矩阵的阅读似乎是正确的。
confusion_matrix 返回一个转置的版本是真的吗? 矩阵?
如上例所示,不。但是一个非常简单(而且看起来很无辜)的错误可能是您交换了y_true
和y_pred
参数的顺序,这很重要;结果确实是一个转置矩阵:
# correct order of arguments:
confusion_matrix(y_true, y_pred)
# array([[0, 2],
# [1, 1]])
# inverted (wrong) order of the arguments:
confusion_matrix(y_pred, y_true)
# array([[0, 1],
# [2, 1]])
从您提供的信息中无法判断这是否是原因,这很好地提醒了您为什么应该始终提供实际代码,而不是对您的想法进行口头描述 你的代码正在做...
【讨论】:
这确实是我的问题,我在你发布答案前 2 分钟就想到了,也非常感谢你的解释,现在问题已经解决了。干杯。以上是关于随机森林中的 class_weight 超参数改变了混淆矩阵中的样本数量的主要内容,如果未能解决你的问题,请参考以下文章
scikit-learn:随机森林 class_weight 和 sample_weight 参数
使用 GridSearchCV 调整 scikit-learn 的随机森林超参数