类的数量必须大于一;得了1分

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了类的数量必须大于一;得了1分相关的知识,希望对你有一定的参考价值。

我正在开发一个机器学习程序,我坚持这个错误。目前我的数据集有2个类,如下所示:

2652,0.09,-1.02,0.43,-0.01,-0.94,0.35,1
1,0.38,-0.90,0.19,0.30,0.95,0.12,2
2653,0.09,-1.02,0.43,-0.01,-0.94,0.35,1
4,0.38,-0.90,0.19,0.29,0.96,0.06,2
5,0.38,-0.90,0.19,0.29,0.96,0.06,2
2654,0.15,-1.01,0.45,-0.01,-0.94,0.35,1
2,0.38,-0.90,0.19,0.29,0.96,0.06,2

当我运行我的代码时,我得到了这个错误

ValueError                                Traceback (most recent call last)
<ipython-input-7-c44a67b01cf1> in <module>
     11 model, params = train_model(X_train, y_train, 
     12                     est=SVC(probability=True),
---> 13                     grid={'C': param_range, 'gamma': param_range, 'kernel': ['linear']})
     14 eval_model(model, X_test, y_test, 'SVC')
     15 

<ipython-input-5-d902442b6ba1> in train_model(X, y, est, grid)
      2     print('::::Train Model::::')
      3     gs = GridSearchCV(estimator=est, param_grid=grid, scoring='accuracy', cv=4, n_jobs=-1)
----> 4     gs = gs.fit(X, y)
      5 
      6     return (gs.best_estimator_, gs.best_params_)
.
.
.
ValueError: The number of classes has to be greater than one; got 1 class

但我已经意识到在这部分代码中

feats, y = get_simple_features(data, wsize='10s')
# split data into train and test sets

X_train, X_test, y_train, y_test = train_test_split(feats, y, test_size=.25, random_state=0, stratify=y)


print('Support Vector Machine')
model, params = train_model(X_train, y_train, 
                    est=SVC(probability=True),
                    grid={'C': param_range, 'gamma': param_range, 'kernel': ['linear']})
eval_model(model, X_test, y_test, 'SVC')

当我做print(np.unique(y))时,输出是[1]。它发生在这行代码中:

y = data['label'].resample(wsize, how=lambda ts: mode(ts)[0] if ts.shape[0] > 0 else np.nan)  

因为data ['label']有两个类,但重采样的结果只有1个类。但是,我已经要求另一个人运行我的代码,并且根本没有错误。

它能是什么?

PS:Here是完整的代码。

答案

这是因为在运行resample函数时你正在进行的重采样的随机性,特别是因为样本量太小(<10)并且它不是分层采样,你很可能得到一个只代表一个类的样本。

以上是关于类的数量必须大于一;得了1分的主要内容,如果未能解决你的问题,请参考以下文章

ValueError:类的数量必须大于一;得到 1

ValueError:类的数量必须大于一(python)

Hive Bucketing:不同列值的数量大于分桶数量

如何用servlet写一个简单的购物车系统

什么是在 C++ 中获取总内核数量的跨平台代码片段? [复制]

为啥尽管源代码没有变化,但从一个系统到另一个系统的片段数量却有很大差异?