在 scikit-learn 库中使用 sgd 求解器的 SGDClassifier 与 LogisticRegression

Posted

技术标签:

【中文标题】在 scikit-learn 库中使用 sgd 求解器的 SGDClassifier 与 LogisticRegression【英文标题】:SGDClassifier vs LogisticRegression with sgd solver in scikit-learn library 【发布时间】:2017-10-13 03:50:55 【问题描述】:

scikit-learn 库具有以下看起来相似的分类器:

逻辑回归分类器有不同的求解器,其中之一 是'sgd'

http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression

它还有一个不同的分类器'SGDClassifier'和损失 对于逻辑回归,参数可以称为“log”。

http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html#sklearn.linear_model.SGDClassifier

它们本质上是相同的还是不同的?如果它们不同,那么两者之间的实现有何不同?考虑到逻辑回归问题,您如何决定使用哪一个?

【问题讨论】:

LogisticRegression 模块没有 SGD 算法('newton-cg'、'lbfgs'、'liblinear'、'sag'),但模块 SGDClassifier 也可以解决 LogisticRegression。这意味着您有 5 个可以使用的求解器。这些之间存在巨大差异,并且文档中给出了一些可供选择的规则(例如,第 1 组中的哪一个)。 SGD 通常用于非常有效的大规模问题。与其他人相比,它可能非常依赖于选择的超参数(学习率、衰减……)。糟糕的超参数不仅会导致性能下降,还会导致糟糕的结果(未达到全局最小值) 谢谢。我对“sag”和“sgd”感到困惑。 “sag”是指随机平均梯度吗?我认为这类似于 sgd,除非随机平均梯度与随机平均梯度下降有很大不同。 @sascha "SAG" 代表“S随机A平均G辐射下降”。见scikit-learn.org/stable/modules/… 【参考方案1】:

Sklearn 中的逻辑回归虽然没有“sgd”求解器。它实现了对数正则化逻辑回归:它最小化了对数概率。

SGDClassifier 是一个广义线性分类器,它将使用随机梯度下降作为求解器。正如这里提到的 http://scikit-learn.org/stable/modules/sgd.html :“尽管 SGD 在机器学习社区中已经存在了很长时间,但它最近在大规模学习的背景下受到了相当多的关注。” 它易于实施且高效。例如,这是用于神经网络的求解器之一。

借助 SGDClassifier,您可以使用许多不同的损失函数(一个最小化或最大化以找到最佳解决方案的函数),让您可以“调整”您的模型并为您的数据找到基于 sgd 的最佳线性模型。确实,某些数据结构或某些问题需要不同的损失函数。

在您的示例中,SGD 分类器将具有与 Logistic 回归相同的损失函数,但求解器不同。根据您的数据,您可以得到不同的结果。你可以尝试使用交叉验证找到最好的,甚至尝试网格搜索交叉验证来找到最好的超参数。

希望能回答您的问题。

【讨论】:

非常好的答案! 感谢您说得这么清楚!我没有将“损失”中的“日志”与逻辑回归联系起来!【参考方案2】:

基本上,SGD 就像一把伞,能够面对不同的线性函数。 SGD 是一种近似算法,例如采用单个单点,随着点数的增加,它会更多地转换为最优解。因此,它主要用于数据集较大的情况。 逻辑回归默认使用梯度下降,因此速度较慢(如果在大型数据集上进行比较) 为了让 SGD 对任何特定的线性函数表现良好,让我们在这里说逻辑回归,我们调整称为超参数调整的参数

【讨论】:

【参考方案3】:

所有线性分类器(SVM、逻辑回归、a.o.)都可以使用 sgd: Stochastic Gradient Descent

【讨论】:

以上是关于在 scikit-learn 库中使用 sgd 求解器的 SGDClassifier 与 LogisticRegression的主要内容,如果未能解决你的问题,请参考以下文章

用随机梯度下降法(SGD)做线性拟合

SciKit SGD 回归器 RBF 内核逼近

SciKit-Learn使用什么方法来进行基本的LinearRegression?

在 scikit-learn 中训练神经网络时提前停止

Python Scikit-Learn 库中分类数据的异常值预测

在 sklearn 中,具有线性内核的 SVM 模型和具有 loss=hinge 的 SGD 分类器有啥区别