SVC 的网格搜索:IndexError:数组索引过多

Posted

技术标签:

【中文标题】SVC 的网格搜索:IndexError:数组索引过多【英文标题】:Grid search for SVC : IndexError: too many indices for array 【发布时间】:2019-10-16 07:36:54 【问题描述】:

我正在尝试使用 GridSearchCVSVC 找到最佳参数。

from sklearn.svm import SVC
from sklearn import svm, grid_search
from sklearn.model_selection import GridSearchCV

param_grid = [
        'C': [1,5,10,100],
        ]
algo = SVC(kernel="poly",  degree=5, coef0=2)
grid_search = GridSearchCV(algo, param_grid, cv=3, scoring='neg_mean_squared_error')
grid_search.fit(X_train, y_train)
print(grid_search.best_params_) #line 162

我收到以下错误:

  File "main.py", line 162, in <module>
  IndexError: too many indices for array

当我不使用 GridSearchCV 时,它可以工作:

from sklearn.svm import SVC
from sklearn import svm, grid_search
from sklearn.model_selection import GridSearchCV

algo = SVC(kernel="poly", C=1, degree=5, coef0=2)
algo.fit(X_train, y_train)
predict_test = algo.predict(X_test)
mse = mean_squared_error(y_test, predict_test)
rmse = np.sqrt(mse)
print(rmse)

我得到一个分数。

【问题讨论】:

y_train.shape 的输出是什么? y_train.shape的输出是(892, 1),X_train.shape的输出是(892, 14) grid_search.fit(X_train, y_train) 之前使用y_train = y_train.reshape(892,)。现在有什么错误吗? 是的!谢谢!出了什么问题?为什么 (892,) 有效(没有其他属性),而 (892,1) 无效? 【参考方案1】:

GridSearchCV.fit() 接受目标值作为类似数组的y,形状为[n_samples][n_samples, n_output]

在你的情况下,(892,)。因此,重塑y_train

y_train = y_train.reshape(892,)

【讨论】:

以上是关于SVC 的网格搜索:IndexError:数组索引过多的主要内容,如果未能解决你的问题,请参考以下文章

IndexError:数组的索引过多:数组是二维的,但有 3 个被索引

SVC 的网格搜索问题 - 如何排除故障?

如何解决“IndexError:数组索引过多”

IndexError:布尔索引与索引数组不匹配

IndexError:数组的索引过多

IndexError:数组的索引过多。具有 42 个特征的 Numpy 数组不均匀