PRML 1.1 多项式曲线拟合

Posted Real&Love

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PRML 1.1 多项式曲线拟合相关的知识,希望对你有一定的参考价值。

PRML 1.1 多项式曲线拟合


  • 输入 训练集

    x ≡ ( x 1 , . . . , x N ) T x\\equiv (x_1,...,x_N)^T x(x1,...,xN)T
    t ≡ ( t 1 , . . . , t N ) T t\\equiv (t_1,...,t_N)^T t(t1,...,tN)T

  • 输出 拟合曲线

1.1.1 代码

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(1)
X = np.linspace(0, 1, 10)
y = np.sin(2*np.pi*X) + np.random.normal(0.1, 0.1, 10)  # 加入噪声
X_true = np.linspace(0, 1, 256, endpoint=True)
y_true = np.sin(2*np.pi*X_true)

plt.scatter(X, y, c="b", alpha=0.6)
plt.plot(X_true, y_true, c="g")
plt.show()

在这里插入图片描述

可以看到绿色的是潜在的待发现的函数 sin ⁡ ( 2 π x ) \\sin(2\\pi x) sin(2πx),也就是我们最终想预测到对拟合曲线,但是现在根据输入【10个点的数据集】来进行拟合的。

1.1.2 多项式推导

我们需要用一个公式来拟合这些点,假设这是一个关于x的多项式

y ( x , ω ) = ω 0 + ω 1 x + ω 2 x 2 + . . . + ω M x M = ∑ j = 0 M ω j x j y(x,\\omega)=\\omega_0+\\omega_1x+\\omega_2x^2+...+\\omega_Mx^M=\\sum_{j=0}^{M}{\\omega_jx^j} y(x,ω)=ω0+ω1x+ω2x2+...+ωMxM=j=0Mωjxj

  • 上面公式中 M M M表示多项式的阶数
  • M = 1 M=1 M=1时,为简单的线性回归方程

当𝑀=0M=0或𝑀=1M=1时,拟合曲线如下图上部分的红线所示

我们肉眼可以看到,拟合效果是非常差的。我们怎么量化这种训练时的误差呢?故引出下面常见的一种度量方法。
每个数据点的预测值 y ( x n , ω ) y(x_n,\\omega) y(xn,ω)和真实值 t n t_n tn之间的平方和,这个 E ( ω ) E(\\omega) E(ω)很明显越小越好
E ( ω ) = 1 2 ∑ n = 1 N [ y ( x n , ω ) − t n ] 2 E(\\omega)=\\frac{1}{2}\\sum_{n=1}^{N}{[y(x_n,\\omega)-t_n]^2} E(ω)=21n=1N[y(xn,ω)tn]2
在这里插入图片描述

1.1.3 过拟合

解决过拟合的方法较多,如调节模型参数数量,让模型变得简单;加入正则项惩罚模型的复杂度,增加训练集的个数

在这里插入图片描述

1.1.4 正则化

我们给误差函数增加一个惩罚项。如下所示:
E ( ω ) = 1 2 ∑ n = 1 N [ y ( x n , ω ) − t n ] 2 + λ 2 ∥ ω ∥ 2 {E}(\\omega)=\\frac{1}{2}\\sum_{n=1}^{N}{[y(x_n,\\omega)-t_n]^2}+\\frac{\\lambda }{2} \\left \\| \\omega \\right \\|^2 E(ω)=21n=1N[y(xn,ω)tn]2+2λω2

其中
∥ ω ∥ 2 = ω T ω = ω 0 2 + ω 1 2 + . . . + ω M 2 \\left \\| \\omega \\right \\|^2=\\omega^T\\omega=\\omega_{0}^{2}+\\omega_{1}^{2}+ ...+\\omega_{M}^{2} ω2=ωTω=ω02+ω12+...+ωM2

  • 统计学 : 收缩法(shrinkage)【xgboost也会用到此方法】
  • 神经网络: 权值衰减(weight decay)

以上是关于PRML 1.1 多项式曲线拟合的主要内容,如果未能解决你的问题,请参考以下文章

PRML 学习: Polynomial Curve Fitting

PRML 学习: Polynomial Curve Fitting

简单的PRML阅读笔记

python实现logistic增长模型

C++曲线拟合代码

MATLAB点云处理(十七):最小二乘多项式曲线拟合