Sklearn:如何获得对训练数据进行分类的均方误差

Posted

技术标签:

【中文标题】Sklearn:如何获得对训练数据进行分类的均方误差【英文标题】:Sklearn: how to get mean squared error on classifying training data 【发布时间】:2018-07-12 16:45:37 【问题描述】:

我第一次尝试在 Python 中使用 sklearn 来解决一些分类问题,并且想知道仅根据训练数据计算分类器(如 SVM)的误差的最佳方法是什么。

我计算准确率和rmse的示例代码如下:

    svc = svm.SVC(kernel='rbf', C=C, decision_function_shape='ovr').fit(X_train, y_train.ravel())
    prediction = svc.predict(X_test)
    svm_in_accuracy.append(svc.score(X_train,y_train))
    svm_out_rmse.append(sqrt(mean_squared_error(prediction, np.array(list(y_test)))))
    svm_out_accuracy.append((np.array(list(y_test)) == prediction).sum()/(np.array(list(y_test)) == prediction).size)

我从 'sklearn.metrics import mean_squared_error' 知道几乎可以让我获得用于样本外比较的 MSE。我可以在 sklearn 中做些什么来给我一个关于我的模型在训练数据上错误分类的好/不好的错误度量?我问这个是因为我知道我的数据不是完全线性可分的(这意味着分类器会对某些项目进行错误分类),并且我想知道获得错误度量的最佳方法。任何帮助将不胜感激!

【问题讨论】:

用于分类。您可以使用准确率、召回率和精度 这非常广泛,取决于您的具体问题,而不是 sklearn 问题。首先,RMSE 仅用于回归。对于分类,请使用@AkshayNevrekar 的指标,或者另外使用 AUC 或 Log-Loss。实际研究混淆矩阵或 ROC-ruve 可能很有用。但这实际上取决于您的问题(类的数量,类的平衡,是误报还是误报更多的问题等)。 Sklearn 确实支持所有命名指标,请参阅here 【参考方案1】:

要评估您的分类器,您可以使用以下指标:

from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score

混淆矩阵将预测标签作为列标题,而真实标签是行标签。混淆矩阵的主对角线显示正确分配的标签数量。任何非对角元素都包含错误分配标签的数量。从混淆矩阵中,您还可以计算准确率、精确度和召回率。分类报告和混淆矩阵都易于使用 - 您将测试和预测标签传递给函数:

print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))

[[1047    5]
 [   0  448]]

            precision    recall  f1-score   support

        0.0       1.00      1.00      1.00      1052
        1.0       0.99      1.00      0.99       448

avg / total       1.00      1.00      1.00      1500

其他指标函数计算并绘制 ROC 的接收器操作特征 (ROC) 和曲线下面积 (AUC)。您可以在此处阅读有关 ROC 的信息:

http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html

http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html

【讨论】:

以上是关于Sklearn:如何获得对训练数据进行分类的均方误差的主要内容,如果未能解决你的问题,请参考以下文章

说说最小均方误差(MMSE)

对时间序列数据作出指数平滑预测后,如何用excel计算数据的均方误差(MSE)?

如何在 sklearn 中使用训练有素的 NB 分类器预测电子邮件的标签?

在 sklearn 中使用 DictVectorizer 后如何获得分类特征的重要性

如何计算时间序列的均方误差?

如何在 Sklearn 的随机森林分类器中将训练模型用于另一个数据集?