sklearn 交叉验证中的自定义评分功能

Posted

技术标签:

【中文标题】sklearn 交叉验证中的自定义评分功能【英文标题】:Custom Scoring Function in sklearn Cross Validate 【发布时间】:2019-06-01 17:19:51 【问题描述】:

我想为cross_validate 使用一个自定义函数,它使用特定的y_test 来计算精度,这是一个不同于实际目标y_testy_test

我用make_scorer 尝试了一些方法,但我不知道如何真正通过我的替代y_test

scoring = 'prec1': 'precision',
     'custom_prec1': make_scorer(precision_score()

scores = cross_validate(pipeline, X, y, cv=5,scoring= scoring)

任何人都可以提出一种方法吗?

【问题讨论】:

【参考方案1】:

这样找到的。也许代码不是最优的,对此感到抱歉。

好的,让我们开始吧:

import numpy as np
import pandas as pd

from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import GridSearchCV
from sklearn.metrics.scorer import make_scorer

xTrain = np.random.rand(100, 100)
yTrain = np.random.randint(1, 4, (100, 1))

yTrainCV = np.random.randint(1, 4, (100, 1))

model = LogisticRegression()

yTrainCV 将在此处用作自定义记分器。

def customLoss(xArray, yArray):
    indices = xArray.index.values
    tempArray = [1 if value1 != value2 else 0 for value1, value2 in zip(xArray.values, yTrainCV[[indices]])]

    return sum(tempArray)

scorer = 'main': 'accuracy',
          'custom': make_scorer(customLoss, greater_is_better=True)

这里有一些技巧:

您需要将值传递给 customLoss 2(来自模型的预测值 + 实际值;但我们不使用第二个参数) greater_is_better 有一些游戏:True/False 将返回正数或负数 我们从GridSearchCV 中的 CV 获得的指数

还有……

grid = GridSearchCV(model,
                    scoring=scorer,
                    cv=5,
                    param_grid='C': [1e0, 1e1, 1e2, 1e3],
                                'class_weight': ['balanced', None],
                    refit='custom')

 grid.fit(xTrain, pd.DataFrame(yTrain))
 print(grid.score(xTrain, pd.DataFrame(yTrain)))
不要忘记GridSearchCV中的refit参数 我们在这里将目标数组作为DataFrame 传递 - 这将帮助我们检测自定义损失函数中的索引

【讨论】:

非常感谢阿夫乔佐夫!这是完美的,我非常感谢你的帮助。这太棒了,我希望他们能将此作为 make_scorer 的 sklearn 文档中的示例 感谢@avchauzov 这个解决方案很棒,并且完全解决了没有官方答案的my own question。

以上是关于sklearn 交叉验证中的自定义评分功能的主要内容,如果未能解决你的问题,请参考以下文章

如何将交叉验证目标输入管道中的自定义转换器

使用 sklearn 进行交叉验证的高级特征提取

sklearn:文本分类交叉验证中的向量化

具有交叉验证的 Sklearn 线性回归返回 NA 准确度分数

sklearn:用户定义的时间序列数据交叉验证

如何使用 Sklearn 管道进行参数调整/交叉验证?