软标签上的 scikit-learn 分类

Posted

技术标签:

【中文标题】软标签上的 scikit-learn 分类【英文标题】:scikit-learn classification on soft labels 【发布时间】:2017-08-05 15:52:10 【问题描述】:

根据文档,可以为SGDClassifier 指定不同的损失函数。据我了解,log loss 是一个 cross-entropy 损失函数,理论上可以处理软标签,即以某些概率 [0,1] 给出的标签。

问题是:是否可以将SGDClassifierlog loss 一起使用来解决带有软标签的分类问题?如果不是 - 如何使用 scikit-learn 解决这个任务(软标签上的线性分类)?

更新:

target 的标记方式以及问题的性质,硬标签不会产生好的结果。但这仍然是一个分类问题(不是回归),我不想保留对prediction 的概率解释,所以回归也不能开箱即用。交叉熵损失函数可以自然地处理target 中的软标签。 scikit-learn 中线性分类器的所有损失函数似乎都只能处理硬标签。

所以问题大概是:

例如,如何为SGDClassifier 指定我自己的损失函数。似乎scikit-learn 并没有坚持这里的模块化方法,需要在其源代码中的某个地方进行更改

【问题讨论】:

【参考方案1】:

我最近遇到了这个问题,并想出了一个很好的解决方案,似乎可行。

基本上,使用反 sigmoid 函数将您的目标转换为对数赔率空间。然后拟合线性回归。然后,为了进行推理,从线性回归模型中获取预测的 sigmoid。

假设我们有软目标/标签y ∈ (0, 1)(确保将目标钳制为[1e-8, 1 - 1e-8],以避免在我们记录日志时出现不稳定问题)。

我们取逆 sigmoid,然后拟合线性回归(假设预测变量在矩阵 X 中):

y = np.clip(y, 1e-8, 1 - 1e-8)   # numerical stability
inv_sig_y = np.log(y / (1 - y))  # transform to log-odds-ratio space

from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(X, inv_sig_y)

然后进行预测:

def sigmoid(x):
    ex = np.exp(x)
    return ex / (1 + ex)

preds = sigmoid(lr.predict(X_new))

这似乎可行,至少对于我的用例而言。我的猜测是,无论如何,LogisticRegression 的幕后发生的事情并不遥远。

奖励:这似乎也适用于 sklearn 中的其他回归模型,例如RandomForestRegressor.

【讨论】:

【参考方案2】:

According to the docs,

“log”损失给出逻辑回归,一种概率分类器。

通常,损失函数的形式为Loss( prediction, target ),其中prediction 是模型的输出,target 是真实值。在逻辑回归的情况下,prediction(0,1) 上的值(即“软标签”),而target01(即“硬标签”)。

所以在回答您的问题时,这取决于您指的是prediction 还是target。一般来说,标签的形式(“硬”或“软”)由prediction 选择的算法和target 的现有数据给出。

如果您的数据具有“硬”标签,并且您希望模型输出“软”标签(可以对其进行阈值化以给出“硬”标签),那么是的,逻辑回归属于此类。

如果您的数据具有“软”标签,那么在使用典型分类方法(即逻辑回归)之前,您必须选择一个阈值以将其转换为“硬”标签。否则,您可以使用适合模型的回归方法来预测“软”目标。在后一种方法中,您的模型可能会给出(0,1) 之外的值,而这必须得到处理。

【讨论】:

感谢您的回复。我指的是target。它的标签方式和问题的性质硬标签不会产生好的结果。但这仍然是一个分类问题(不是回归),我不想保留对prediction 的概率解释,所以回归也不能开箱即用。交叉熵损失函数可以自然处理target中的软标签 所以不,据我所知,SGDClassifier 不会处理您的情况,除非您首先对目标标签进行阈值处理。您的另一个选择是,就像我上面所说的那样,使用标准回归技术(并将预测压缩到 [0,1] 之后)使用sklearn 创建一个回归模型,或者使用 Theano 创建一个自己的模型,例如,其中损失可能是 [0,1] 上的连续值预测和 [0,1] 上的连续值目标之间的交叉熵。 是的,可能后者对我来说是更好的方法

以上是关于软标签上的 scikit-learn 分类的主要内容,如果未能解决你的问题,请参考以下文章

使用 scikit-learn 进行多标签文本分类,使用哪些分类器?

scikit-learn 在多标签分类中计算 F1

GridSearch用于Scikit-learn中的多标签分类

如何告诉 scikit-learn 为哪个标签给出了 F-1/precision/recall 分数(在二进制分类中)?

多标签分类器中的拟合概率

软工大项目分工