如何为 Scikit-learn 分类器添加加权损失?
Posted
技术标签:
【中文标题】如何为 Scikit-learn 分类器添加加权损失?【英文标题】:How to add weighted loss to Scikit-learn classifiers? 【发布时间】:2021-06-22 05:27:04 【问题描述】:在许多机器学习应用中,可能需要加权损失,因为某些类型的错误预测可能比其他错误更糟糕。例如。在医学二元分类(健康/生病)中,假阴性(患者没有得到进一步检查)比假阳性(后续检查会发现错误)更糟糕。
所以如果我这样定义一个加权损失函数:
def weighted_loss(prediction, target):
if prediction == target:
return 0 # correct, no loss
elif prediction == 0: # class 0 is healthy
return 100 # false negative, very bad
else:
return 1 # false positive, incorrect
如何将与此等效的内容传递给 scikit-learn 分类器,例如 Random Forests 或 SVM 分类器?
【问题讨论】:
你的意思是class_weight吗? 我不确定。对我来说,班级权重意味着不仅损失而且奖励(正确地上课)都会得到提升,对吧?是否有更深入的解释 class_weight 的作用?我找不到。 class_weight 用于不平衡数据集,其中每个类中有不同数量的样本;为了不训练一个偏向于具有大量样本的类的模型,class_weight 就派上用场了。通过根据您拥有的类数为每个类分配不同的权重,如果当前样本用于训练,则深度神经网络的模型权重不会发生太大变化,反之亦然样本。 好吧,我没有不平衡的数据集,我想人为地使损失不平衡,因为 FP 比 FN 更可取。我从您的评论中得到的是 class_weights 不是我的问题的答案,对吧? 是的,class_weights 不是您问题的答案。但是,您可以做的是开发一个模型,然后使用 sklearn.metrics.classification_report 查看结果。你需要的是高精度分数和相对较高的召回分数。 【参考方案1】:我担心你的问题是不恰当的,源于 loss 和 metric 的不同概念之间的根本混淆。
Loss 函数不适用于prediction == target
-type 条件 - 这就是 metrics(如准确度、精确度、召回率等) do - 然而,它在损失优化(即训练)期间不起作用,仅用于性能评估。损失不适用于硬类预测;它仅适用于分类器的概率输出,而这种相等条件从不适用。
损失和指标之间的额外“绝缘”层是阈值的选择,这是转换分类器的概率输出所必需的(训练期间唯一重要的事情)到“硬”类预测(仅对正在考虑的业务问题重要的事情)。再一次,这个阈值在模型训练期间绝对没有作用(唯一相关的量是损失,它对阈值和硬类预测一无所知);很好地放入交叉验证线程Reduce Classification Probability Threshold:
当您为新样本的每个类别输出一个概率时,您的练习的统计部分就结束了。选择一个阈值,将新观察分类为 1 与 0 的阈值不再是 统计数据 的一部分。它是 decision 组件的一部分。
虽然您当然可以尝试在狭窄定义的模型训练(即损失最小化)之外使用 额外 程序优化此(决策)阈值,正如您在 cmets 中简要描述的那样,您的期望是
我很确定,如果 RBF 绘制的决策边界在拟合数据时考虑到这一点,我会得到更好的结果
使用类似于您的 weight_loss
功能的东西是徒劳的。
因此,不能使用与此处显示的 weight_loss
类似的函数(本质上是一个度量,而不是损失函数,尽管它的名称),它采用像 prediction == target
这样的相等条件,可以使用用于模型训练。
以下 SO 线程中的讨论也可能有助于澄清问题:
Loss & accuracy - Are these reasonable learning curves? What is the difference between loss function and metric in Keras?(尽管有标题,但定义普遍适用,不仅适用于 Keras) Cost function training target versus accuracy desired goal How to interpret loss and accuracy for a machine learning model【讨论】:
以上是关于如何为 Scikit-learn 分类器添加加权损失?的主要内容,如果未能解决你的问题,请参考以下文章