MLP 分类器:“ValueError:未知标签类型”
Posted
技术标签:
【中文标题】MLP 分类器:“ValueError:未知标签类型”【英文标题】:MLP Classifier: "ValueError: Unknown label type" 【发布时间】:2019-06-04 00:47:11 【问题描述】:我正在尝试使用 MLP 分类器创建一个基本的 NN。
当我使用方法mlp.fit
a 得到以下错误:
ValueError: 未知标签类型:(array([
下面是我的简单代码
df_X_train = df_train[["Pe/Pe_nom","Gas_cons","PthLoad"]]
df_Y_train = df_train["Eff_Th"]
df_X_test = df_test[["Pe/Pe_nom","Gas_cons","PthLoad"]]
df_Y_test = df_test["Eff_Th"]
X_train = np.asarray(df_X_train, dtype="float64")
Y_train = np.asarray(df_Y_train, dtype="float64")
X_test = np.asarray(df_X_test, dtype="float64")
Y_test = np.asarray(df_Y_test, dtype="float64")
from sklearn.neural_network import MLPClassifier
mlp = MLPClassifier(hidden_layer_sizes=(100,), verbose=True)
mlp.fit(X_train, Y_train)
其实我不明白为什么fit
这个方法不喜欢X_train
和Y_train
的float类型。
只是为了让矩阵维度下的一切都清楚:
X_train.shape --> (720, 3)
Y_train.shape --> (720,)
我希望我以正确的方式提问,谢谢。
下面是完整的错误:
> --------------------------------------------------------------------------- ValueError Traceback (most recent call
> last) <ipython-input-6-2efb224ab852> in <module>()
> 2
> 3 mlp = MLPClassifier(hidden_layer_sizes=(100,), verbose=True)
> ----> 4 mlp.fit(X_train, Y_train)
> 5
> 6 #y_pred_train = mlp.predict(X_train)
>
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\neural_network\multilayer_perceptron.py
> in fit(self, X, y)
> 971 """
> 972 return self._fit(X, y, incremental=(self.warm_start and
> --> 973 hasattr(self, "classes_")))
> 974
> 975 @property
>
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\neural_network\multilayer_perceptron.py
> in _fit(self, X, y, incremental)
> 329 hidden_layer_sizes)
> 330
> --> 331 X, y = self._validate_input(X, y, incremental)
> 332 n_samples, n_features = X.shape
> 333
>
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\neural_network\multilayer_perceptron.py
> in _validate_input(self, X, y, incremental)
> 914 if not incremental:
> 915 self._label_binarizer = LabelBinarizer()
> --> 916 self._label_binarizer.fit(y)
> 917 self.classes_ = self._label_binarizer.classes_
> 918 elif self.warm_start:
>
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\preprocessing\label.py
> in fit(self, y)
> 282
> 283 self.sparse_input_ = sp.issparse(y)
> --> 284 self.classes_ = unique_labels(y)
> 285 return self
> 286
>
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\utils\multiclass.py
> in unique_labels(*ys)
> 94 _unique_labels = _FN_UNIQUE_LABELS.get(label_type, None)
> 95 if not _unique_labels:
> ---> 96 raise ValueError("Unknown label type: %s" % repr(ys))
> 97
> 98 ys_labels = set(chain.from_iterable(_unique_labels(y) for y in ys))
>
> ValueError: Unknown label type: (array([1. , 0.89534884, 0.58139535, 0.37209302, 0.24418605,
0.15116279, 0.09302326, 0.23255814, 0.34883721, 0.37209302,
0.30232558, 0.23255814, 0.18604651, 0.12790698, 0.08139535,
0.08139535, 0.19767442, 0.27906977, 0.26744186, 0.22093023,
0.1744186 , 0.11627907, 0.06976744, 0.05813953, 0.1744186 ,
0.26744186, 0.34883721, 0.40697674, 0.46511628, 0.45348837,
0.38372093, 0.31395349, 0.26744186, 0.36046512, 0.44186047,
0.48837209, 0.53488372, 0.48837209, 0.40697674, 0.31395349,
0.24418605, 0.1744186 , 0.19767442, 0.29069767, 0.36046512,
0.3255814 , 0.26744186, 0.20930233, 0.13953488, 0.09302326,
0.04651163, 0.09302326, 0.19767442, 0.29069767, 0.26744186,
0.20930233, 0.1627907 , 0.11627907, 0.06976744, 0.03488372,
0.12790698, 0.24418605, 0.31395349, 0.26744186, 0.20930233,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.13953488,
0.25581395, 0.30232558, 0.24418605, 0.19767442, 0.15116279,
0.09302326, 0.05813953, 0.04651163, 0.1627907 , 0.26744186,
0.30232558, 0.24418605, 0.19767442, 0.13953488, 0.09302326,
0.05813953, 0.06976744, 0.18604651, 0.27906977, 0.27906977,
0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.03488372,
0.10465116, 0.22093023, 0.29069767, 0.26744186, 0.22093023,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.12790698,
0.24418605, 0.30232558, 0.25581395, 0.20930233, 0.15116279,
0.10465116, 0.05813953, 0.03488372, 0.15116279, 0.26744186,
0.30232558, 0.25581395, 0.19767442, 0.15116279, 0.09302326,
0.05813953, 0.09302326, 0.20930233, 0.29069767, 0.26744186,
0.22093023, 0.1627907 , 0.11627907, 0.06976744, 0.02325581,
0.12790698, 0.23255814, 0.31395349, 0.26744186, 0.20930233,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.13953488,
0.25581395, 0.31395349, 0.25581395, 0.20930233, 0.15116279,
0.10465116, 0.05813953, 0.02325581, 0.11627907, 0.22093023,
0.29069767, 0.24418605, 0.19767442, 0.13953488, 0.09302326,
0.04651163, 0.02325581, 0.10465116, 0.20930233, 0.30232558,
0.25581395, 0.20930233, 0.15116279, 0.10465116, 0.05813953,
0.03488372, 0.13953488, 0.24418605, 0.31395349, 0.25581395,
0.20930233, 0.15116279, 0.10465116, 0.15116279, 0.26744186,
0.3372093 , 0.36046512, 0.30232558, 0.24418605, 0.19767442,
0.1744186 , 0.25581395, 0.3255814 , 0.38372093, 0.41860465,
0.34883721, 0.29069767, 0.23255814, 0.1627907 , 0.1744186 ,
0.27906977, 0.34883721, 0.3255814 , 0.26744186, 0.20930233,
0.15116279, 0.09302326, 0.04651163, 0.10465116, 0.22093023,
0.30232558, 0.25581395, 0.20930233, 0.15116279, 0.10465116,
0.05813953, 0.02325581, 0.12790698, 0.24418605, 0.30232558,
0.25581395, 0.20930233, 0.15116279, 0.10465116, 0.1627907 ,
0.26744186, 0.37209302, 0.45348837, 0.51162791, 0.55813953,
0.59302326, 0.62790698, 0.56976744, 0.48837209, 0.40697674,
0.36046512, 0.43023256, 0.47674419, 0.48837209, 0.39534884,
0.30232558, 0.23255814, 0.1627907 , 0.10465116, 0.19767442,
0.29069767, 0.31395349, 0.25581395, 0.20930233, 0.15116279,
0.10465116, 0.05813953, 0.02325581, 0.03488372, 0.15116279,
0.25581395, 0.25581395, 0.20930233, 0.15116279, 0.10465116,
0.06976744, 0.03488372, 0.04651163, 0.1627907 , 0.26744186,
0.25581395, 0.20930233, 0.1627907 , 0.11627907, 0.06976744,
0.03488372, 0. , 0.10465116, 0.20930233, 0.27906977,
0.22093023, 0.1744186 , 0.12790698, 0.08139535, 0.08139535,
0.19767442, 0.29069767, 0.36046512, 0.43023256, 0.48837209,
0.53488372, 0.56976744, 0.60465116, 0.52325581, 0.45348837,
0.38372093, 0.45348837, 0.51162791, 0.54651163, 0.54651163,
0.44186047, 0.36046512, 0.27906977, 0.20930233, 0.1744186 ,
0.25581395, 0.3372093 , 0.3372093 , 0.27906977, 0.22093023,
0.1627907 , 0.10465116, 0.05813953, 0.06976744, 0.18604651,
0.27906977, 0.27906977, 0.22093023, 0.1744186 , 0.12790698,
0.08139535, 0.03488372, 0.10465116, 0.22093023, 0.30232558,
0.27906977, 0.22093023, 0.1744186 , 0.11627907, 0.19767442,
0.29069767, 0.36046512, 0.40697674, 0.34883721, 0.29069767,
0.23255814, 0.1744186 , 0.20930233, 0.30232558, 0.36046512,
0.34883721, 0.29069767, 0.23255814, 0.1744186 , 0.11627907,
0.06976744, 0.11627907, 0.22093023, 0.30232558, 0.27906977,
0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.12790698,
0.24418605, 0.3255814 , 0.27906977, 0.23255814, 0.1744186 ,
0.12790698, 0.08139535, 0.03488372, 0. , 0.11627907,
0.22093023, 0.27906977, 0.22093023, 0.1744186 , 0.12790698,
0.08139535, 0.04651163, 0.02325581, 0.11627907, 0.23255814,
0.30232558, 0.25581395, 0.19767442, 0.15116279, 0.10465116,
0.05813953, 0.08139535, 0.19767442, 0.29069767, 0.29069767,
0.23255814, 0.18604651, 0.13953488, 0.08139535, 0.04651163,
0.06976744, 0.18604651, 0.27906977, 0.27906977, 0.23255814,
0.1744186 , 0.12790698, 0.08139535, 0.04651163, 0.12790698,
0.24418605, 0.3255814 , 0.27906977, 0.22093023, 0.1744186 ,
0.11627907, 0.06976744, 0.03488372, 0.13953488, 0.24418605,
0.30232558, 0.25581395, 0.19767442, 0.15116279, 0.10465116,
0.05813953, 0.02325581, 0.13953488, 0.24418605, 0.26744186,
0.22093023, 0.1744186 , 0.12790698, 0.06976744, 0.03488372,
0.08139535, 0.19767442, 0.27906977, 0.29069767, 0.24418605,
0.19767442, 0.13953488, 0.09302326, 0.11627907, 0.23255814,
0.3255814 , 0.30232558, 0.25581395, 0.19767442, 0.15116279,
0.09302326, 0.04651163, 0.08139535, 0.19767442, 0.27906977,
0.31395349, 0.25581395, 0.19767442, 0.15116279, 0.10465116,
0.05813953, 0.09302326, 0.20930233, 0.30232558, 0.27906977,
0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.03488372,
0.03488372, 0.15116279, 0.25581395, 0.26744186, 0.20930233,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.01162791,
0.12790698, 0.23255814, 0.31395349, 0.29069767, 0.24418605,
0.18604651, 0.13953488, 0.09302326, 0.05813953, 0.1744186 ,
0.27906977, 0.34883721, 0.29069767, 0.23255814, 0.1744186 ,
0.11627907, 0.06976744, 0.09302326, 0.19767442, 0.30232558,
0.31395349, 0.26744186, 0.20930233, 0.15116279, 0.10465116,
0.05813953, 0.09302326, 0.20930233, 0.30232558, 0.27906977,
0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.03488372,
0.08139535, 0.20930233, 0.29069767, 0.26744186, 0.20930233,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.09302326,
0.20930233, 0.27906977, 0.23255814, 0.18604651, 0.13953488,
0.09302326, 0.04651163, 0.05813953, 0.18604651, 0.26744186,
0.3372093 , 0.30232558, 0.24418605, 0.19767442, 0.13953488,
0.09302326, 0.1744186 , 0.27906977, 0.34883721, 0.30232558,
0.24418605, 0.18604651, 0.13953488, 0.08139535, 0.03488372,
0.04651163, 0.1627907 , 0.26744186, 0.26744186, 0.22093023,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.03488372,
0.15116279, 0.25581395, 0.27906977, 0.22093023, 0.1744186 ,
0.12790698, 0.08139535, 0.03488372, 0.01162791, 0.12790698,
0.23255814, 0.29069767, 0.24418605, 0.19767442, 0.13953488,
0.09302326, 0.05813953, 0.05813953, 0.1744186 , 0.27906977,
0.29069767, 0.24418605, 0.18604651, 0.13953488, 0.09302326,
0.11627907, 0.23255814, 0.30232558, 0.34883721, 0.29069767,
0.24418605, 0.18604651, 0.12790698, 0.15116279, 0.25581395,
0.3255814 , 0.30232558, 0.24418605, 0.19767442, 0.13953488,
0.09302326, 0.12790698, 0.22093023, 0.30232558, 0.25581395,
0.20930233, 0.1627907 , 0.11627907, 0.05813953, 0.02325581,
0.05813953, 0.1744186 , 0.26744186, 0.22093023, 0.1744186 ,
0.12790698, 0.08139535, 0.04651163, 0.01162791, 0.11627907,
0.22093023, 0.25581395, 0.22093023, 0.1744186 , 0.12790698,
0.08139535, 0.03488372, 0.08139535, 0.19767442, 0.27906977,
0.34883721, 0.29069767, 0.24418605, 0.18604651, 0.13953488,
0.10465116, 0.22093023, 0.30232558, 0.3255814 , 0.27906977,
0.22093023, 0.1627907 , 0.10465116, 0.05813953, 0.02325581,
0.12790698, 0.24418605, 0.29069767, 0.24418605, 0.19767442,
0.13953488, 0.09302326, 0.05813953, 0.02325581, 0.10465116,
0.22093023, 0.30232558, 0.24418605, 0.19767442, 0.15116279,
0.09302326, 0.05813953, 0.02325581, 0.06976744, 0.18604651,
0.27906977, 0.25581395, 0.20930233, 0.1627907 , 0.10465116,
0.06976744, 0.03488372, 0.04651163, 0.1627907 , 0.25581395,
0.3255814 , 0.38372093, 0.44186047, 0.41860465, 0.34883721,
0.29069767, 0.24418605, 0.25581395, 0.34883721, 0.41860465,
0.46511628, 0.5 , 0.51162791, 0.41860465, 0.3372093 ,
0.26744186, 0.20930233, 0.20930233, 0.30232558, 0.37209302,
0.36046512, 0.29069767, 0.22093023, 0.15116279, 0.10465116,
0.09302326, 0.19767442, 0.27906977, 0.25581395, 0.20930233,
0.1627907 , 0.11627907, 0.06976744, 0.02325581, 0.08139535,
0.19767442, 0.26744186, 0.22093023, 0.1744186 , 0.13953488,
0.09302326, 0.04651163, 0.02325581, 0.13953488, 0.24418605,
0.26744186, 0.22093023, 0.1744186 , 0.12790698, 0.08139535,
0.1744186 , 0.26744186, 0.34883721, 0.40697674, 0.46511628,
0.41860465, 0.34883721, 0.27906977, 0.22093023, 0.18604651,
0.27906977, 0.34883721, 0.37209302, 0.30232558, 0.24418605,
0.1744186 , 0.11627907, 0.06976744, 0.03488372, 0.15116279]),)
【问题讨论】:
您能检查一下type(X_train)
和type(Y_train)
返回的内容吗?也许您的数组被包装在不同的数据类型中
我不确定这是否会导致您的问题,但它可能仍然是一个问题:我认为 MLPClassifier.fit 也希望 Y 是矩形,即在您的情况下具有形状 (720, 1 ) 而不是 (720,)。您可以轻松地将df_Y_train = df_train["Eff_Th"]
替换为df_Y_train = df_train[["Eff_Th"]]
。
我认为在你的问题中,你想写Y_train.shape --> (720,)
@MarcoSpinaci:只要尺寸匹配,就不会造成任何麻烦。试试X_train = np.array([[1,2,3],[10,20,30],[100,200,300]])
和Y_train = np.array([3, 30, 300])
你会发现它有效。我认为这个问题与X_train和Y_train的数据类型有关。
@offeltoffel 谢谢你的回答,变量 X_train 和 Y_train 都是 dtype('float64')
【参考方案1】:
当你说你的目标变量需要浮点精度时,看起来你需要MLPRegressor 而不是 MLPClassifier。
【讨论】:
没错。现在有了完整的回溯,我们可以看到 MLPC 正在尝试将浮点值作为标签来学习。在原始问题中,错误消息似乎在(array([
之后结束
谢谢你们,MLPRegressor 工作正常。以上是关于MLP 分类器:“ValueError:未知标签类型”的主要内容,如果未能解决你的问题,请参考以下文章
基于Halcon的MLP(多层感知神经网络)分类器分类操作实例