为啥“StratifiedShuffleSplit”对数据集的每个拆分都给出相同的结果?

Posted

技术标签:

【中文标题】为啥“StratifiedShuffleSplit”对数据集的每个拆分都给出相同的结果?【英文标题】:Why does "StratifiedShuffleSplit" give the same result for every split of dataset?为什么“StratifiedShuffleSplit”对数据集的每个拆分都给出相同的结果? 【发布时间】:2021-06-03 15:35:01 【问题描述】:

我正在使用StratifiedShuffleSplit 重复拆分数据集、拟合、预测和计算指标的过程。您能否解释一下为什么每次拆分都给出相同的结果?

import csv
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import classification_report

clf = RandomForestClassifier(max_depth = 5)
df = pd.read_csv("https://raw.githubusercontent.com/leanhdung1994/BigData/main/cll_dataset.csv")
X, y = df.iloc[:, 1:], df.iloc[:, 0]
sss = StratifiedShuffleSplit(n_splits = 5, test_size = 0.25, random_state = 0).split(X, y)

for train_ind, test_ind in sss:
    X_train, X_test = X.loc[train_ind], X.loc[test_ind]
    y_train, y_test = y.loc[train_ind], y.loc[test_ind]
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    report = classification_report(y_test, y_pred, zero_division = 0, output_dict = True)
    report = pd.DataFrame(report).T
    report = report[:2]
    print(report)

结果是

   precision  recall  f1-score  support
0       0.75     1.0  0.857143      6.0
1       0.00     0.0  0.000000      2.0
   precision  recall  f1-score  support
0       0.75     1.0  0.857143      6.0
1       0.00     0.0  0.000000      2.0
   precision  recall  f1-score  support
0       0.75     1.0  0.857143      6.0
1       0.00     0.0  0.000000      2.0
   precision  recall  f1-score  support
0       0.75     1.0  0.857143      6.0
1       0.00     0.0  0.000000      2.0
   precision  recall  f1-score  support
0       0.75     1.0  0.857143      6.0
1       0.00     0.0  0.000000      2.0

【问题讨论】:

【参考方案1】:

您构建的每个模型都预测输出始终为 0 类,并且由于您已分层拆分(0 类和 1 类的比例始终与 X 相同),因此您始终预测完全相同的值。

与“学习”某些模式或规则相比,模型始终预测 0 类时会获得更好的准确性。这是一个巨大的问题。要解决它,您有以下一些选择:

尝试修改随机森林算法的一些超参数。 收集更多数据以获得更大的数据集,您只测试 8 个样本(可能是 很难为您获取新数据) 您的数据不平衡(0 类样本多于 1 类样本),您 应该考虑使用SMOTE 库来平衡它

【讨论】:

以上是关于为啥“StratifiedShuffleSplit”对数据集的每个拆分都给出相同的结果?的主要内容,如果未能解决你的问题,请参考以下文章

“TypeError:'StratifiedShuffleSplit'对象不可迭代”的原因可能是啥?

这个错误对 StratifiedShuffleSplit 意味着啥?

StratifiedKFold vs StratifiedShuffleSplit vs StratifiedKFold + Shuffle

在 cross_val_predict (sklearn) 中使用 StratifiedShuffleSplit

sklearn可视化不同数据划分方法的差异:KFold, ShuffleSplit,StratifiedKFold, GroupKFold, StratifiedShuffleSplit.......

带有索引的 scikit-learn StratifiedShuffleSplit KeyError