Scikit-Learn 逻辑回归严重过拟合数字分类训练数据
Posted
技术标签:
【中文标题】Scikit-Learn 逻辑回归严重过拟合数字分类训练数据【英文标题】:Scikit-Learn's Logistic Regression severely overfits digit classification training data 【发布时间】:2021-01-09 23:30:03 【问题描述】:我正在使用 Scikit-Learn 的逻辑回归算法来执行数字分类。我使用的数据集是 Scikit-Learn 的 load_digits。
以下是我的代码的简化版本:
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import learning_curve
from sklearn.datasets import load_digits
digits = load_digits()
model = LogisticRegression(solver ='lbfgs',
penalty = 'none',
max_iter = 1e5,
multi_class = 'auto')
model.fit(digits.data, digits.target)
predictions = model.predict(digits.data)
df_cm = pd.DataFrame(confusion_matrix(digits.target, predictions))
ax = sns.heatmap(df_cm, annot = True, cbar = False, cmap = 'Blues_r', fmt='d', annot_kws = "size": 10)
ax.set_ylim(0,10)
plt.title("Confusion Matrix")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
train_size = [0.2, 0.4, 0.6, 0.8, 1]
training_size, training_score, validation_score = learning_curve(model, digits.data, digits.target, cv = 5,
train_sizes = train_size, scoring = 'neg_mean_squared_error')
training_scores_mean = - training_score.mean(axis = 1)
validation_score_mean = - validation_score.mean(axis = 1)
plt.plot(training_size, validation_score_mean)
plt.plot(training_size, training_scores_mean)
plt.legend(["Validation error", "Training error"])
plt.ylabel("MSE")
plt.xlabel("Training set size")
plt.show()
### EDIT ###
# With L2 regularization
model = LogisticRegression(solver ='lbfgs',
penalty = 'l2', # Changing penality to l2
max_iter = 1e5,
multi_class = 'auto')
model.fit(digits.data, digits.target)
predictions = model.predict(digits.data)
df_cm = pd.DataFrame(confusion_matrix(digits.target, predictions))
ax = sns.heatmap(df_cm, annot = True, cbar = False, cmap = 'Blues_r', fmt='d', annot_kws = "size": 10)
ax.set_ylim(0,10)
plt.title("Confusion Matrix with L2 regularization")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
training_size, training_score, validation_score = learning_curve(model, digits.data, digits.target, cv = 5,
train_sizes = train_size, scoring = 'neg_mean_squared_error')
training_scores_mean = - training_score.mean(axis = 1)
validation_score_mean = - validation_score.mean(axis = 1)
plt.plot(training_size, validation_score_mean)
plt.plot(training_size, training_scores_mean)
plt.legend(["Validation error", "Training error"])
plt.title("Learning curve with L2 regularization")
plt.ylabel("MSE")
plt.xlabel("Training set size")
plt.show()
# With L2 regularization and best C
from sklearn.model_selection import GridSearchCV
C = 'C': [1e-3, 1e-2, 1e-1, 1, 10]
model_l2 = GridSearchCV(LogisticRegression(random_state = 0, solver ='lbfgs', penalty = 'l2', max_iter = 1e5, multi_class = 'auto'),
param_grid = C, cv = 5, iid = False, scoring = 'neg_mean_squared_error')
model_l2.fit(digits.data, digits.target)
best_C = model_l2.best_params_.get("C")
print(best_C)
model_reg = LogisticRegression(solver ='lbfgs',
penalty = 'l2',
C = best_C,
max_iter = 1e5,
multi_class = 'auto')
model_reg.fit(digits.data, digits.target)
predictions = model_reg.predict(digits.data)
df_cm = pd.DataFrame(confusion_matrix(digits.target, predictions))
ax = sns.heatmap(df_cm, annot = True, cbar = False, cmap = 'Blues_r', fmt='d', annot_kws = "size": 10)
ax.set_ylim(0,10)
plt.title("Confusion Matrix with L2 regularization and best C")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
training_size, training_score, validation_score = learning_curve(model_reg, digits.data, digits.target, cv = 5,
train_sizes = train_size, scoring = 'neg_mean_squared_error')
training_scores_mean = - training_score.mean(axis = 1)
validation_score_mean = - validation_score.mean(axis = 1)
plt.plot(training_size, validation_score_mean)
plt.plot(training_size, training_scores_mean)
plt.legend(["Validation error", "Training error"])
plt.title("Learning curve with L2 regularization and best C")
plt.ylabel("MSE")
plt.xlabel("Training set size")
plt.show()
从训练数据的混淆矩阵和使用 learning_curve 生成的最后一个图中可以看出,训练集上的误差始终为 0:
Learning Curve Plot Here
在我看来,该模型严重过度拟合,我无法理解它。我也尝试过使用 MNIST 数据集,但发生了同样的事情。
我该如何解决这个问题?
-- 编辑--
在代码上方添加 L2 正则化,并为超参数 C 设置最佳值。
使用 L2 正则化,模型仍然过拟合数据:
Learning Curve with L2 regularization here
使用最好的 C 超参数,训练数据上的误差不再为零,但算法仍然过拟合:
Learning Curve with L2 regularization here and best C here
还是不明白怎么回事……
【问题讨论】:
【参考方案1】:使用正则化术语(惩罚)而不是“无”。
model = LogisticRegression(solver ='lbfgs',
penalty = 'l2',
max_iter = 1e5,
multi_class = 'auto')
您在验证曲线中找到的 C 的最佳值。
【讨论】:
感谢您的回复。仅仅将惩罚更改为“l2”并不会改变结果——模型继续过度拟合训练数据。在我的原始代码中,我使用 GridSearchCV 来找到 C 的最佳值。我在编辑中添加了上面的相应代码。训练集上的误差不再为零,但这对我来说仍然没有意义 - 即使对于如此小的训练数据集大小,即使有或没有正则化(并且没有最佳 C),误差怎么可能为零? 是的,如果您不指定它,则值为 C=1.0(这就是它在文档中所说的)。所以你在什么时间间隔内进行网格搜索?也许你使用的间隔太窄了。我建议对 GridSearchCV 执行此操作:parameters = "C": np.logspace(-4, 4, 50) #Define grid.以上是关于Scikit-Learn 逻辑回归严重过拟合数字分类训练数据的主要内容,如果未能解决你的问题,请参考以下文章