# python sklearn:使用交叉验证调整参数
# Posted
# tags:
# 篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了python sklearn:使用交叉验证调整参数相关的知识,希望对你有一定的参考价值。
"""
This file uses cross validation to tune the parameters.
"""
import multiprocessing
import sys
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, ParameterGrid
import joblib
from joblib import Parallel, delayed
from tqdm import tqdm
from model.hmm import ConstrainedMixHMM
from utils.utils import load_data, send_email
# Global seaborn theme; applies to every figure drawn later in this process
# (e.g. the optional log-likelihood KDE plot in fit_model).
sns.set_style('whitegrid')
# Show 3 decimal places when printing numpy arrays and pandas objects.
np.set_printoptions(precision=3)
# Use the fully-qualified option key: the bare pattern 'precision' matches
# several options in newer pandas (e.g. 'display.precision' and
# 'styler.format.precision') and then raises OptionError.
pd.set_option('display.precision', 3)
# Name of this experiment; used to build the output and figure directories.
task_name = "HMMCrossValidation"
# Directory holding the input CSV. Cached CV splits and fitted models are
# written to a per-task sub-directory of it (output_folder).
input_folder = '../data/output'
output_folder = '../data/output/%s' % task_name
# Not referenced by the functions in this script.
figure_folder = '../figure/%s' % task_name

# Keyword arguments forwarded to utils.load_data; see that function for the
# meaning of each flag.
data_config = dict(
    path='%s/RNZS_CC_F_include_missing.csv' % (input_folder),
    aMCI_only=True,
    split_aMCI_naMCI=False,
    split_early_clinic_MCI=True,
    include_missing_x=True,
    include_missing_y=True,
    impute_missing=False,
    reviewed_label_only=False
)

# Settings read by fit_model and the drivers:
#   run_times - number of random restarts per fit; also caps joblib's n_jobs
#   verbose   - print a message before each fit
#   plot      - draw a KDE of the restarts' final log-likelihoods
run_config = dict(
    run_times=10,
    verbose=False,
    plot=False
)

# Base keyword arguments for ConstrainedMixHMM. 'n_components' and
# 'transmat_type' are filled in per setting from param_grid.
model_config = dict(
    monotonic_state=True,
    covariance_type='diag',
)

# Hyper-parameter grid expanded with sklearn's ParameterGrid.
# range() is wrapped in list(): on Python 3 a range object cannot be
# concatenated with a list (on Python 2 range() already returned a list).
param_grid = {
    'transmat_type': ['upper-bidiagonal'],
    'n_components': list(range(3, 15)) + [17, 20, 24],
}
def _fit(X, lengths, random_state, config):
    """Build one ConstrainedMixHMM from ``config`` and fit it.

    ``random_state`` is passed to the constructor together with the keyword
    arguments in ``config``; ``X`` and ``lengths`` go straight to ``fit``.
    Returns whatever ``fit`` returns (the callers treat it as the fitted model).
    """
    hmm = ConstrainedMixHMM(random_state=random_state, **config)
    fitted = hmm.fit(X, lengths)
    return fitted
def fit_model(parallel, X, lengths, config):
    """Fit the model from several random starts and return the best one.

    One fit is run per seed ``0 .. run_config['run_times'] - 1`` (the seed is
    used as ``random_state``) on the supplied joblib ``parallel`` pool. The fit
    whose last ``monitor_.history`` value (its final log-likelihood) is largest
    is returned.

    Parameters
    ----------
    parallel : joblib.Parallel
        An open ``Parallel`` instance used to run the fits.
    X : array-like
        Observations, passed unchanged to ``ConstrainedMixHMM.fit``.
    lengths : array-like of int
        Sequence lengths, passed unchanged to ``ConstrainedMixHMM.fit``.
    config : dict
        Keyword arguments for the ``ConstrainedMixHMM`` constructor.

    Returns
    -------
    ConstrainedMixHMM
        The fitted model with the highest final log-likelihood.
    """
    if run_config['verbose']:
        # Parenthesized print works on both Python 2 and Python 3.
        print("Fitting the model...")
    models = parallel(
        delayed(_fit)(X, lengths, seed, config)
        for seed in range(run_config['run_times'])
    )
    # Final value recorded by each model's monitor.
    scores = [m.monitor_.history[-1] for m in models]
    if run_config['plot']:
        plt.figure()
        sns.kdeplot(np.array(scores))
        # Mark the best restart on the x-axis of the density plot.
        plt.plot(np.max(scores), 0, 'ro')
        plt.title('kernel density estimation of log-likelihood')
    return models[np.argmax(scores)]
def run_cross_validation(n_splits, notification=False):
    """Tune the number of hidden states (and transition type) with K-fold CV.

    The unique values of the data's first index level -- one per observation
    sequence -- are split into ``n_splits`` folds with ``KFold``. The fold
    indices are cached in ``<output_folder>/cv_splits.pkl`` and reused on later
    runs, so every parameter setting sees the same folds.

    For each setting in ``param_grid`` and each fold, the best of
    ``run_config['run_times']`` random starts (see ``fit_model``) is trained on
    the fold's training sequences. It is saved to
    ``<output_folder>/cv=<fold>_transmat=<type>_n=<n_components>.pkl``.
    Folds whose model file already loads are skipped, so the function can be
    re-run to resume. Only the training side of each split is used here.

    Parameters
    ----------
    n_splits : int
        Number of folds for cross-validation. Only used when no cached split
        file exists yet; an existing ``cv_splits.pkl`` is reused as-is, even if
        it holds a different number of folds.
    notification : bool
        Whether to report progress via ``send_email`` before and after each
        fold is fitted.
    """
    X = load_data(**data_config)
    # IDs of the observation sequences (first index level).
    # remove_unused_levels() drops level values that no longer have rows, which
    # a row filter inside load_data may leave behind; otherwise X.loc would be
    # asked for IDs that are not present. When no values are unused the result
    # equals X.index.levels[0], so previously cached split indices stay valid.
    sequence_ids = X.index.remove_unused_levels().levels[0]

    splits_path = "%s/cv_splits.pkl" % output_folder
    try:
        k_splits = joblib.load(splits_path)
    except IOError:
        # First run: draw the folds once and cache them for later runs.
        k_splits = list(KFold(n_splits, shuffle=True).split(sequence_ids))
        joblib.dump(k_splits, splits_path)
    print(k_splits)

    n_jobs = min(int(multiprocessing.cpu_count() * 1.5), run_config['run_times'])
    with Parallel(n_jobs=n_jobs) as parallel:
        for params in tqdm(ParameterGrid(param_grid)):
            print("Fitting %r using cross-validation..." % params)
            # Per-setting config built from a copy, so the module-level
            # model_config is never mutated.
            config = dict(model_config)
            config['n_components'] = params['n_components']
            config['transmat_type'] = params['transmat_type']

            for i, (train_index, _) in tqdm(enumerate(k_splits),
                                            total=len(k_splits)):
                model_path = "%s/cv=%d_transmat=%s_n=%d.pkl" % (
                    output_folder, i, params['transmat_type'],
                    params['n_components'])
                try:
                    ConstrainedMixHMM.load(model_path)
                except IOError:
                    pass  # not fitted yet -> fit below
                else:
                    continue  # already fitted by an earlier run

                if notification:
                    send_email("Cross Validation",
                               "Start fitting %d-fold with %r." % (i, params))
                ids_train = sequence_ids[train_index].tolist()
                X_train = X.loc[ids_train]
                # Rows per training sequence. sort=False keeps the counts in order
                # of first appearance, i.e. X_train's row order; the default
                # sort=True only agrees when X_train happens to be sorted by ID.
                # NOTE(review): assumes fit() reads `lengths` in row order
                # (hmmlearn convention) -- confirm for ConstrainedMixHMM.
                lengths_train = X_train.groupby(level=0, sort=False).size().values
                model = fit_model(parallel, X_train, lengths_train, config)
                model.save(model_path)
                if notification:
                    send_email("Cross Validation",
                               "Finished fitting %d-fold with %r." % (i, params))
def fit_model_on_whole_dataset():
    """Fit one model per ``param_grid`` setting on all of the data.

    Each model is the best of ``run_config['run_times']`` random starts (see
    ``fit_model``). It is saved to
    ``<output_folder>/model_type=<type>_n=<n_components>.pkl``.
    Settings whose model file already loads are skipped.
    """
    X = load_data(**data_config)
    # Rows per observation sequence (first index level). sort=False keeps the
    # counts in order of first appearance, i.e. X's row order; the default
    # sort=True only agrees when X is sorted by its first level.
    # NOTE(review): assumes fit() reads `lengths` in row order (hmmlearn
    # convention) -- confirm for ConstrainedMixHMM.
    lengths = X.groupby(level=0, sort=False).size().values

    n_jobs = min(int(multiprocessing.cpu_count() * 1.5), run_config['run_times'])
    with Parallel(n_jobs=n_jobs) as parallel:
        for params in tqdm(ParameterGrid(param_grid)):
            print("fitting %r ..." % params)
            # File name kept exactly as before ("model_type=" rather than the
            # CV files' "transmat=") so previously saved models are still found.
            filename = "%s/model_type=%s_n=%d.pkl" % (
                output_folder, params['transmat_type'], params['n_components'])
            try:
                ConstrainedMixHMM.load(filename)
            except IOError:
                pass  # not fitted yet -> fit below
            else:
                continue  # already fitted by an earlier run

            # Copy so the module-level model_config is never mutated.
            config = dict(model_config)
            config['n_components'] = params['n_components']
            config['transmat_type'] = params['transmat_type']
            model = fit_model(parallel, X, lengths, config)
            model.save(filename)
if __name__ == "__main__":
    # Usage: python <script> MODE
    # MODE is matched by substring. "CV" runs cross-validation and "WHOLE" fits
    # on the full data set; a value containing both (e.g. "CV+WHOLE") runs both.
    # Validate explicitly instead of with `assert`, which is stripped under -O
    # and gives no usage hint.
    if len(sys.argv) != 2:
        sys.exit("usage: %s {CV|WHOLE|CV+WHOLE}" % sys.argv[0])
    mode = sys.argv[1]
    if "CV" not in mode and "WHOLE" not in mode:
        sys.exit("unknown mode %r: expected it to contain 'CV' and/or 'WHOLE'"
                 % mode)
    if "CV" in mode:
        run_cross_validation(n_splits=5, notification=True)
    if "WHOLE" in mode:
        fit_model_on_whole_dataset()
# 以上是关于python sklearn:使用交叉验证调整参数的主要内容,如果未能解决你的问题,请参考以下文章
# 如何将 KerasClassifier、Hyperopt 和 Sklearn 交叉验证放在一起