class_weight 参数在 scikit-learn SGD 中的作用是啥

Posted

技术标签:

【中文标题】class_weight 参数在 scikit-learn SGD 中的作用是啥【英文标题】:What is class_weight parameter does in scikit-learn SGDclass_weight 参数在 scikit-learn SGD 中的作用是什么 【发布时间】:2015-05-22 05:51:39 【问题描述】:

我是 scikit-learn 的常客,我想了解有关 SGD 的“class_weight”参数的一些见解。

直到函数调用我才能弄清楚

plain_sgd(coef, intercept, est.loss_function,
                 penalty_type, alpha, C, est.l1_ratio,
                 dataset, n_iter, int(est.fit_intercept),
                 int(est.verbose), int(est.shuffle), est.random_state,
                 pos_weight, neg_weight,
                 learning_rate_type, est.eta0,
                 est.power_t, est.t_, intercept_decay)

https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/stochastic_gradient.py

在这之后它转到 sgd_fast 并且我对 cpython 不是很好。你能在这些问题上给出一些迅速。

    我在开发集中有一个类有偏差,其中正类是 15k,负类是 36k。 class_weight 会解决这个问题吗?或者进行欠采样将是一个更好的主意。我的数字越来越好,但很难解释。 如果是,那么它实际上是如何做到的。我的意思是它是应用于特征惩罚还是优化函数的权重。我该如何向外行解释?

【问题讨论】:

【参考方案1】:

class_weight 确实可以帮助提高在不平衡数据上训练的分类模型的 ROC AUC 或 f1-score。

您可以尝试class_weight="auto" 选择与班级频率成反比的权重。你也可以尝试传递你自己的权重有一个python字典,其中类标签作为键,权重作为值。

可以通过交叉验证的网格搜索来调整权重。

在内部,这是通过从class_weight 派生sample_weight 来完成的(取决于每个样本的类标签)。然后使用样本权重来衡量单个样本对损失函数的贡献,该损失函数用于训练具有随机梯度下降的线性分类模型。

特征惩罚通过penaltyalpha 超参数独立控制。 sample_weight/class_weight对它没有影响。

【讨论】:

以上是关于class_weight 参数在 scikit-learn SGD 中的作用是啥的主要内容,如果未能解决你的问题,请参考以下文章

scikit-learn:随机森林 class_weight 和 sample_weight 参数

如何在 sklearn 0.14 版中设置“class_weight”?

keras中model.compile的参数'weighted_metrics'和model.fit_generator的参数'class_weight'之间的区别?

随机森林中的 class_weight 超参数改变了混淆矩阵中的样本数量

class_weights 或加权损失在哪里惩罚网络?

sklearn 分类的 class_weight 字典格式