软标签上的 scikit-learn 分类
Posted
技术标签:
【中文标题】软标签上的 scikit-learn 分类【英文标题】:scikit-learn classification on soft labels 【发布时间】:2017-08-05 15:52:10 【问题描述】:根据文档,可以为SGDClassifier
指定不同的损失函数。据我了解,log loss
是一个 cross-entropy
损失函数,理论上可以处理软标签,即以某些概率 [0,1] 给出的标签。
问题是:是否可以将SGDClassifier
与log loss
一起使用来解决带有软标签的分类问题?如果不是 - 如何使用 scikit-learn 解决这个任务(软标签上的线性分类)?
更新:
target
的标记方式以及问题的性质,硬标签不会产生好的结果。但这仍然是一个分类问题(不是回归),我不想保留对prediction
的概率解释,所以回归也不能开箱即用。交叉熵损失函数可以自然地处理target
中的软标签。 scikit-learn 中线性分类器的所有损失函数似乎都只能处理硬标签。
所以问题大概是:
例如,如何为SGDClassifier
指定我自己的损失函数。似乎scikit-learn
并没有坚持这里的模块化方法,需要在其源代码中的某个地方进行更改
【问题讨论】:
【参考方案1】:我最近遇到了这个问题,并想出了一个很好的解决方案,似乎可行。
基本上,使用反 sigmoid 函数将您的目标转换为对数赔率空间。然后拟合线性回归。然后,为了进行推理,从线性回归模型中获取预测的 sigmoid。
假设我们有软目标/标签y ∈ (0, 1)
(确保将目标钳制为[1e-8, 1 - 1e-8]
,以避免在我们记录日志时出现不稳定问题)。
我们取逆 sigmoid,然后拟合线性回归(假设预测变量在矩阵 X
中):
y = np.clip(y, 1e-8, 1 - 1e-8) # numerical stability
inv_sig_y = np.log(y / (1 - y)) # transform to log-odds-ratio space
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(X, inv_sig_y)
然后进行预测:
def sigmoid(x):
ex = np.exp(x)
return ex / (1 + ex)
preds = sigmoid(lr.predict(X_new))
这似乎可行,至少对于我的用例而言。我的猜测是,无论如何,LogisticRegression 的幕后发生的事情并不遥远。
奖励:这似乎也适用于 sklearn
中的其他回归模型,例如RandomForestRegressor
.
【讨论】:
【参考方案2】:According to the docs,
“log”损失给出逻辑回归,一种概率分类器。
通常,损失函数的形式为Loss( prediction, target )
,其中prediction
是模型的输出,target
是真实值。在逻辑回归的情况下,prediction
是(0,1)
上的值(即“软标签”),而target
是0
或1
(即“硬标签”)。
所以在回答您的问题时,这取决于您指的是prediction
还是target
。一般来说,标签的形式(“硬”或“软”)由prediction
选择的算法和target
的现有数据给出。
如果您的数据具有“硬”标签,并且您希望模型输出“软”标签(可以对其进行阈值化以给出“硬”标签),那么是的,逻辑回归属于此类。
如果您的数据具有“软”标签,那么在使用典型分类方法(即逻辑回归)之前,您必须选择一个阈值以将其转换为“硬”标签。否则,您可以使用适合模型的回归方法来预测“软”目标。在后一种方法中,您的模型可能会给出(0,1)
之外的值,而这必须得到处理。
【讨论】:
感谢您的回复。我指的是target
。它的标签方式和问题的性质硬标签不会产生好的结果。但这仍然是一个分类问题(不是回归),我不想保留对prediction
的概率解释,所以回归也不能开箱即用。交叉熵损失函数可以自然处理target
中的软标签
所以不,据我所知,SGDClassifier
不会处理您的情况,除非您首先对目标标签进行阈值处理。您的另一个选择是,就像我上面所说的那样,使用标准回归技术(并将预测压缩到 [0,1] 之后)使用sklearn
创建一个回归模型,或者使用 Theano 创建一个自己的模型,例如,其中损失可能是 [0,1] 上的连续值预测和 [0,1] 上的连续值目标之间的交叉熵。
是的,可能后者对我来说是更好的方法以上是关于软标签上的 scikit-learn 分类的主要内容,如果未能解决你的问题,请参考以下文章
使用 scikit-learn 进行多标签文本分类,使用哪些分类器?
GridSearch用于Scikit-learn中的多标签分类
如何告诉 scikit-learn 为哪个标签给出了 F-1/precision/recall 分数(在二进制分类中)?