如何在 keras 分类器中使用交叉验证

Posted

技术标签:

【中文标题】如何在 keras 分类器中使用交叉验证【英文标题】:How to use cross validation in keras classifier 【发布时间】:2021-01-17 11:31:41 【问题描述】:

我正在练习不平衡数据的 keras 分类。我按照官方的例子:

https://keras.io/examples/structured_data/imbalanced_classification/

并使用 scikit-learn api 进行交叉验证。 我已经尝试了具有不同参数的模型。 但是,3 个折叠中的一个始终为 0。

例如。

results [0.99242424 0.99236641 0.        ]

我做错了什么? 如何获取订单“0.8”的所有三个验证召回值?

MWE

%%time
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split

from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import StratifiedKFold

import os
import random
SEED = 100
os.environ['PYTHONHASHSEED'] = str(SEED)
np.random.seed(SEED)
random.seed(SEED)
tf.random.set_seed(SEED)

# load the data
ifile = "https://github.com/bhishanpdl/Datasets/blob/master/Projects/Fraud_detection/raw/creditcard.csv.zip?raw=true"
df = pd.read_csv(ifile,compression='zip')

# train test split
target = 'Class'
Xtrain,Xtest,ytrain,ytest = train_test_split(df.drop([target],axis=1),
    df[target],test_size=0.2,stratify=df[target],random_state=SEED)

print(f"Xtrain shape: Xtrain.shape")
print(f"ytrain shape: ytrain.shape")


# build the model
def build_fn(n_feats):
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(256, activation="relu", input_shape=(n_feats,)))
    model.add(keras.layers.Dense(256, activation="relu"))
    model.add(keras.layers.Dropout(0.3))
    model.add(keras.layers.Dense(256, activation="relu"))
    model.add(keras.layers.Dropout(0.3))

    # last layer is dense 1 for binary sigmoid
    model.add(keras.layers.Dense(1, activation="sigmoid"))

    # compile
    model.compile(loss='binary_crossentropy',
                optimizer=keras.optimizers.Adam(1e-2),
                metrics=['Recall'])

    return model

# fitting the model
n_feats      = Xtrain.shape[-1]
counts = np.bincount(ytrain)
weight_for_0 = 1.0 / counts[0]
weight_for_1 = 1.0 / counts[1]
class_weight = 0: weight_for_0, 1: weight_for_1
FIT_PARAMS   = 'class_weight' : class_weight

clf_keras = KerasClassifier(build_fn=build_fn,
                            n_feats=n_feats, # custom argument
                            epochs=30,
                            batch_size=2048,
                            verbose=2)

skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=SEED)
results = cross_val_score(clf_keras, Xtrain, ytrain,
                          cv=skf,
                          scoring='recall',
                          fit_params = FIT_PARAMS,
                          n_jobs = -1,
                          error_score='raise'
                          )

print('results', results)

结果

Xtrain shape: (227845, 30)
ytrain shape: (227845,)
results [0.99242424 0.99236641 0.        ]
CPU times: user 3.62 s, sys: 117 ms, total: 3.74 s
Wall time: 5min 15s

问题

我得到的第三次召回为 0。我期望它的顺序为 0.8,如何确保所有三个值都在 0.8 左右或更高?

【问题讨论】:

您确定,优化召回是您想要的吗?请注意,recall=t_p/(t_p+f_n)。如果您的模型仅学习预测正类,则召回率等于最大值 1。使用这样一个经过训练的分类器,您可能会得到很多误报,但它们根本不会影响分数。 这是为了学习目的,我们也可以使用F1-score作为指标。 error_score='raise 放入cross_val_score - 这将显示错误。出于某种原因,scikit-learn 默认会抑制该函数内部的错误。 【参考方案1】:

银河001,

您已选择为您的模型使用 sklearn 包装器 - 它们有好处,但模型训练过程是隐藏的。相反,我使用添加的验证数据集单独训练模型。代码如下:

clf_1 = KerasClassifier(build_fn=build_fn,
                       n_feats=n_feats)

clf_1.fit(Xtrain, ytrain, class_weight=class_weight,
          validation_data=(Xtest, ytest),
          epochs=30,batch_size=2048,
          verbose=1)
     

Model.fit() 输出中可以清楚地看到,虽然损失指标下降,但召回并不稳定。正如您所观察到的,这会导致 CV 表现不佳,反映在 CV 结果为零。

我通过将学习率降低到 0.0001 来解决这个问题。虽然它比你的少 100 倍 - 它在训练中达到 98% 的召回率,在测试中达到 100%(或接近)仅 10 个 epoch。

您的代码只需要一次修复即可获得稳定的结果:将 LR 更改为低得多的值,例如 0.0001:

optimizer=keras.optimizers.Adam(1e-4),

您可以在 0.0001 我得到了:

results [0.99242424 0.97709924 1.        ]

祝你好运!

PS:感谢您提供紧凑而完整的 MWE

【讨论】:

以上是关于如何在 keras 分类器中使用交叉验证的主要内容,如果未能解决你的问题,请参考以下文章

Keras训练神经网络进行分类并进行交叉验证(Cross Validation)

如何对目录中的 keras 图像数据集使用交叉验证?

如何在交叉验证中获得 Keras scikit-learn 包装器的训练和验证损失?

用于二进制分类的 ResNet - 只有 2 个交叉验证准确度值

使用 Keras 和 sklearn GridSearchCV 交叉验证提前停止

如何进行交叉验证 SVM 分类器