使用 MLPRegressor 拟合简单数据时遇到问题
Posted
技术标签:
【中文标题】使用 MLPRegressor 拟合简单数据时遇到问题【英文标题】:Trouble fitting simple data with MLPRegressor 【发布时间】:2017-04-25 11:48:43 【问题描述】:我正在尝试 Python 和 scikit-learn。我无法让 MLPRegressor 更接近数据。这是哪里出了问题?
from sklearn.neural_network import MLPRegressor
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(0.0, 1, 0.01).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel()
reg = MLPRegressor(hidden_layer_sizes=(10,), activation='relu', solver='adam', alpha=0.001,batch_size='auto',
learning_rate='constant', learning_rate_init=0.01, power_t=0.5, max_iter=1000, shuffle=True,
random_state=None, tol=0.0001, verbose=False, warm_start=False, momentum=0.9,
nesterovs_momentum=True, early_stopping=False, validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
epsilon=1e-08)
reg = reg.fit(x, y)
test_x = np.arange(0.0, 1, 0.05).reshape(-1, 1)
test_y = reg.predict(test_x)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(x, y, s=10, c='b', marker="s", label='real')
ax1.scatter(test_x,test_y, s=10, c='r', marker="o", label='NN Prediction')
plt.show()
结果不是很好: 谢谢。
【问题讨论】:
这是我使用 Keras (TensorFlow) 和 Gekko 编写的教程,其中包含简单的正弦波示例:apmonitor.com/do/index.php/Main/DeepLearning 这对于那些想要与发布的已经非常好的解决方案进行比较的人可能很有用在这里。 【参考方案1】:你只需要
将求解器更改为'lbfgs'
。 default'adam'
是一种类似 SGD 的方法,对于大而杂乱的数据很有效,但对于这种平滑且小的数据就没什么用了。
使用平滑激活函数,例如tanh
。 relu
几乎是线性的,不适合学习这种简单的非线性函数。
这是result 和完整代码。即使只有 3 个隐藏的神经元也能达到非常高的准确率。
from sklearn.neural_network import MLPRegressor
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(0.0, 1, 0.01).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel()
nn = MLPRegressor(hidden_layer_sizes=(3),
activation='tanh', solver='lbfgs')
n = nn.fit(x, y)
test_x = np.arange(-0.1, 1.1, 0.01).reshape(-1, 1)
test_y = nn.predict(test_x)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(x, y, s=5, c='b', marker="o", label='real')
ax1.plot(test_x,test_y, c='r', label='NN Prediction')
plt.legend()
plt.show()
【讨论】:
我同意你的回答,它更符合曲线。我可以看到为什么 ReLu 不适合连续函数,但是亚当对这种“平滑和小数据”问题无效的原因是什么?我想了解更多。 有见地的回答。谢谢!【参考方案2】:拟合此非线性模型的点太少,因此拟合对种子很敏感。一颗好种子会有所帮助,但它不是先验的。您还可以添加更多数据点。
通过迭代各种种子,我确定random_state=9
可以正常工作。当然还有其他人。
from sklearn.neural_network import MLPRegressor
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(0.0, 1, 0.01).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel()
nn = MLPRegressor(
hidden_layer_sizes=(10,), activation='relu', solver='adam', alpha=0.001, batch_size='auto',
learning_rate='constant', learning_rate_init=0.01, power_t=0.5, max_iter=1000, shuffle=True,
random_state=9, tol=0.0001, verbose=False, warm_start=False, momentum=0.9, nesterovs_momentum=True,
early_stopping=False, validation_fraction=0.1, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
n = nn.fit(x, y)
test_x = np.arange(0.0, 1, 0.05).reshape(-1, 1)
test_y = nn.predict(test_x)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(x, y, s=1, c='b', marker="s", label='real')
ax1.scatter(test_x,test_y, s=10, c='r', marker="o", label='NN Prediction')
plt.show()
这里是种子整数 i = 0..9
的拟合绝对误差:
print(i, sum(abs(test_y - np.sin(2 * np.pi * test_x).ravel())))
产生:
0 13.0874999193
1 7.2879574143
2 6.81003360188
3 5.73859777885
4 12.7245375367
5 7.43361211586
6 7.04137436733
7 7.42966661997
8 7.35516939164
9 2.87247035261
现在,即使使用random_state=0
,我们仍然可以通过将目标点的数量从 100 增加到 1000 并将隐藏层的大小从 10 增加到 100 来改进拟合:
from sklearn.neural_network import MLPRegressor
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(0.0, 1, 0.001).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel()
nn = MLPRegressor(
hidden_layer_sizes=(100,), activation='relu', solver='adam', alpha=0.001, batch_size='auto',
learning_rate='constant', learning_rate_init=0.01, power_t=0.5, max_iter=1000, shuffle=True,
random_state=0, tol=0.0001, verbose=False, warm_start=False, momentum=0.9, nesterovs_momentum=True,
early_stopping=False, validation_fraction=0.1, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
n = nn.fit(x, y)
test_x = np.arange(0.0, 1, 0.05).reshape(-1, 1)
test_y = nn.predict(test_x)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(x, y, s=1, c='b', marker="s", label='real')
ax1.scatter(test_x,test_y, s=10, c='r', marker="o", label='NN Prediction')
plt.show()
产量:
顺便说一句,您的MLPRegressor()
中的某些参数是不必要的,例如momentum
、nesterovs_momentum
等。请查看文档。此外,它有助于播种您的示例以确保结果是可重现的;)
【讨论】:
以上是关于使用 MLPRegressor 拟合简单数据时遇到问题的主要内容,如果未能解决你的问题,请参考以下文章
Scikit-learn MLPRegressor - 如何不预测负面结果?