用 scikit-learn 拟合一维数据来预测线

Posted

技术标签:

【中文标题】用 scikit-learn 拟合一维数据来预测线【英文标题】:Fit one-dimensional data with scikit-learn to predict line 【发布时间】:2017-08-13 15:50:00 【问题描述】:

我用 scikit-learn 编写了代码来为一维玩具数据构建 SVR 预测模型,然后用 matplotlib 绘制它。

蓝线是真实数据。具有线性内核的模型符合一条不错的线,但对于 2 级内核,预测不是我所期望的。我想要一个模型来预测蓝线的值略低于橙色线的预测值。我画了一条黑线来形象化我的想法。

    为什么会这样?数据似乎是 2 次多项式的一个很好的候选者。黑色趋势线跟随真实数据,然后在右边很晚地减少,如果我只看这个,应该比绿线提供的拟合更好阴谋。不应该根据数据找到具有 2 次多项式的模型吗?它也会在靠近蓝线的 X = 0 处很好地弯曲,而不是在该处具有更高估计 y 值的曲率。

    如何实现我想要的模型?

下面的代码是完整且独立的,运行它得到上面的图(减去画黑线)

# some data
y = [0, 3642, 6414, 9844, 13210, 16072, 18868, 22275, 25551, 28949, 31680, 34412, 37290, 39858, 42557, 
    45094, 47354, 49547, 51874, 54534, 55987, 55987, 58377, 60767, 63109, 65060, 66865, 68540, 70328, 
    72035, 73905, 75791, 77873, 79791, 81775, 83726]
X = range(0, len(y))
X_longer = range(0, len(y)*2)

# train models
from sklearn.svm import SVR
import numpy as np
clf_1 = SVR(kernel='poly', C=1e3, degree=1)
clf_2 = SVR(kernel='poly', C=1e3, degree=2)

clf_1.fit(np.array(X).reshape(-1, 1), y)
clf_2.fit(np.array(X).reshape(-1, 1), y)

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

# plot real data
plt.plot(X, y, linewidth=8.0, label='true data')

predicted_1_y = []
predicted_2_y = []

# predict data points based on models
for i in X_longer:
    predicted_1_y.append(clf_1.predict(np.array([i]).reshape(-1, 1)))
    predicted_2_y.append(clf_2.predict(np.array([i]).reshape(-1, 1)))

# plot model predictions
plt.plot(X_longer, predicted_1_y, linewidth=6.0, ls=":", label='model, degree 1')
plt.plot(X_longer, predicted_2_y, linewidth=6.0, ls=":", label='model, degree 2')

plt.legend(loc='upper left')
plt.show()

【问题讨论】:

【参考方案1】:

发生这种情况是因为线性和二次特征最终总是会向上或向下增长。您需要像平方根或对数这样的运算来获取所需的衰减特征。

一种简单的方法是在拟合之前转换输入信号。例如,假设一个平方根趋势:

X = np.array(X)[:,None]**2
clf = SVR(kernel='linear').fit(X, y) 

对于更一般的用例,如果你真的不知道你想要的趋势,或者不想假设这样的特定转换,你可以尝试像 Eureqa 这样的回归工具来计算最佳转换和数学模型可能。

【讨论】:

只是说它最终总会上升或下降并不能真正解释任何事情。他想要的趋势线可能是一个二次趋势,直到向右一定距离才会下降。 @BrenBarn 我想这取决于他最终想要什么。如果没有像这样的更严格的约束,很难控制模型如何推断或超出数据的范围。 @BrenBarn 写的也是我的假设。上面的代码没有运行,TypeError: list indices must be integers, not tuple.

以上是关于用 scikit-learn 拟合一维数据来预测线的主要内容,如果未能解决你的问题,请参考以下文章

Scikit-learn KMeans 聚类 - 用 X 特征拟合集群,用 X-1 特征预测集群成员?

逻辑回归

机器学习算法的随机数据生成

机器学习算法的随机数据生成

用 scikit-learn 拟合向量自回归模型

用 Python 拟合和预测数据库中每一行的线性回归