BayesSearchCV 在 SGDClassifier 参数调整期间不起作用

Posted

技术标签:

【中文标题】BayesSearchCV 在 SGDClassifier 参数调整期间不起作用【英文标题】:BayesSearchCV is not working during SGDClassifier parameter tuning 【发布时间】:2020-10-17 07:59:37 【问题描述】:

我正在尝试使用 BayesSearchCV 来调整 SGDClassifier 的参数。以下是我尝试过的代码。

import seaborn
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from skopt import BayesSearchCV
from sklearn.linear_model import SGDClassifier

df = seaborn.load_dataset("iris")
df_features = df.drop(['species'], axis=1)
df_target = df[['species']]

label_encoder = LabelEncoder()
df_target['species'] = list(label_encoder.fit_transform(df['species'].values.tolist()))

X_train, X_test, y_train, y_test = train_test_split(df_features, df_target, test_size=0.25, random_state=0)

model = SGDClassifier()

model_param = 
    'penalty': ['l2', 'l1', 'elasticnet'],
    'l1_ratio': [0, 0.05, 0.1, 0.2, 0.5, 0.8, 0.9, 0.95, 1],
    'loss': ['hinge', 'log', 'modified_huber', 'squared_hinge', 'perceptron', 'squared_loss', 'huber',
             'epsilon_insensitive', 'squared_epsilon_insensitive'],
    'alpha': [10 ** x for x in range(-6, 1)],
    'random_state': [0]


opt = BayesSearchCV(model, model_param, n_iter=32, cv=3)
opt.fit(X_train, y_train)
opt_pred_values = opt.predict(X_test)

正在创建以下错误:

ValueError: invalid literal for int() with base 10: '0.8'

我还使用相同的 model_param 列表测试了 GridSearchCV 和 RandomizedSearchCV,它们工作正常。如何正确使用 BayesSearchCV?我必须在哪里更改或必须删除哪个参数?

[更新]

如果我从 model_param 中删除“l1_ratio”,那么上面的代码就可以工作了。如何执行保持'l1_ratio'?

【问题讨论】:

【参考方案1】:

经过几次参数组合后,我发现如果我删除 'l1_ratio' 就可以了。然后我尝试了如下所示的“l1_ratio”:

'l1_ratio': [0.0, 0.05, 0.1, 0.2, 0.5, 0.8, 0.9, 0.95, 1.0]
'l1_ratio': [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.85, 0.9, 1]
'l1_ratio': [10 ** x for x in range(-1, 1)]
'l1_ratio': [float(x/10) for x in range(1, 10)]

一切正常。所以最后我在 'l1_ratio' 的搜索空间里把 0 改成了 0.0 和 1 改成了 1.0。

我将解决方案保留在这里以备将来使用。也许有一天有人会受益。

【讨论】:

以上是关于BayesSearchCV 在 SGDClassifier 参数调整期间不起作用的主要内容,如果未能解决你的问题,请参考以下文章

skopt BayesSearchCV 中的 n_points 是如何工作的?

如何为 GridSearchCV 提供交叉验证的索引列表?

分配的变量引用在哪里,在堆栈中还是在堆中?

NOIP 2015 & SDOI 2016 Round1 & CTSC 2016 & SDOI2016 Round2游记

秋的潇洒在啥?在啥在啥?

上传的数据在云端的怎么查看,保存在啥位置?