为啥 scipy.optimize.curve_fit 不适合数据?

Posted

技术标签:

【中文标题】为啥 scipy.optimize.curve_fit 不适合数据?【英文标题】:Why does scipy.optimize.curve_fit not fit to the data?为什么 scipy.optimize.curve_fit 不适合数据? 【发布时间】:2013-03-15 11:32:45 【问题描述】:

一段时间以来,我一直在尝试使用 scipy.optimize.curve_fit 对某些数据进行指数拟合,但我遇到了真正的困难。我真的看不出这不起作用的任何原因,但它只会产生一条直线,不知道为什么!

任何帮助将不胜感激

from __future__ import division
import numpy
from scipy.optimize import curve_fit
import matplotlib.pyplot as pyplot

def func(x,a,b,c):
   return a*numpy.exp(-b*x)-c


yData = numpy.load('yData.npy')
xData = numpy.load('xData.npy')

trialX = numpy.linspace(xData[0],xData[-1],1000)

# Fit a polynomial 
fitted = numpy.polyfit(xData, yData, 10)[::-1]
y = numpy.zeros(len(trailX))
for i in range(len(fitted)):
   y += fitted[i]*trialX**i

# Fit an exponential
popt, pcov = curve_fit(func, xData, yData)
yEXP = func(trialX, *popt)

pyplot.figure()
pyplot.plot(xData, yData, label='Data', marker='o')
pyplot.plot(trialX, yEXP, 'r-',ls='--', label="Exp Fit")
pyplot.plot(trialX,   y, label = '10 Deg Poly')
pyplot.legend()
pyplot.show()

xData = [1e-06, 2e-06, 3e-06, 4e-06,
5e-06, 6e-06, 7e-06, 8e-06,
9e-06, 1e-05, 2e-05, 3e-05,
4e-05, 5e-05, 6e-05, 7e-05,
8e-05, 9e-05, 0.0001, 0.0002,
0.0003, 0.0004, 0.0005, 0.0006,
0.0007, 0.0008, 0.0009, 0.001,
0.002, 0.003, 0.004, 0.005,
0.006, 0.007, 0.008, 0.009, 0.01]

yData = 
[6.37420666067e-09, 1.13082012115e-08,
1.52835756975e-08, 2.19214493931e-08, 2.71258852882e-08, 3.38556130078e-08, 3.55765277358e-08,
4.13818145846e-08, 4.72543475372e-08, 4.85834751151e-08, 9.53876562077e-08, 1.45110636413e-07,
1.83066627931e-07, 2.10138415308e-07, 2.43503982686e-07, 2.72107045549e-07, 3.02911771395e-07,
3.26499455951e-07, 3.48319349445e-07, 5.13187669283e-07, 5.98480176303e-07, 6.57028222701e-07,
6.98347073045e-07, 7.28699930335e-07, 7.50686502279e-07, 7.7015576866e-07, 7.87147246927e-07,
7.99607141001e-07, 8.61398763228e-07, 8.84272900407e-07, 8.96463883243e-07, 9.04105135329e-07,
9.08443443149e-07, 9.12391264185e-07, 9.150842683e-07, 9.16878548643e-07, 9.18389990067e-07]

【问题讨论】:

我在尝试运行您的代码时遇到多个错误 - 首先,trialX 拼写错误,然后我收到 operands could not be broadcast together with shapes 错误。你确定这是你的确切代码吗? @DavidRobinson:要处理操作数问题,请确保xDatayData 都是ndarrays。 【参考方案1】:

模型a*exp(-b*x)+c 与数据非常吻合,但我建议稍作修改: 改用这个

a*x*exp(-b*x)+c

祝你好运

【讨论】:

【参考方案2】:

在不考虑数据先验知识的情况下,对此解决方案的(轻微)改进可能如下:取数据集的反均值并将其用作“比例因子”以传递给底层由 curve_fit() 调用的 minimumsq()。这允许拟合器工作并返回数据原始比例的参数。

相关行是:

popt, pcov = curve_fit(func, xData, yData)

变成:

popt, pcov = curve_fit(func, xData, yData,
    diag=(1./xData.mean(),1./yData.mean()) )

这是生成此图像的完整示例:

from __future__ import division
import numpy
from scipy.optimize import curve_fit
import matplotlib.pyplot as pyplot

def func(x,a,b,c):
   return a*numpy.exp(-b*x)-c


xData = numpy.array([1e-06, 2e-06, 3e-06, 4e-06, 5e-06, 6e-06,
7e-06, 8e-06, 9e-06, 1e-05, 2e-05, 3e-05, 4e-05, 5e-05, 6e-05,
7e-05, 8e-05, 9e-05, 0.0001, 0.0002, 0.0003, 0.0004, 0.0005,
0.0006, 0.0007, 0.0008, 0.0009, 0.001, 0.002, 0.003, 0.004, 0.005
, 0.006, 0.007, 0.008, 0.009, 0.01])

yData = numpy.array([6.37420666067e-09, 1.13082012115e-08,
1.52835756975e-08, 2.19214493931e-08, 2.71258852882e-08,
3.38556130078e-08, 3.55765277358e-08, 4.13818145846e-08,
4.72543475372e-08, 4.85834751151e-08, 9.53876562077e-08,
1.45110636413e-07, 1.83066627931e-07, 2.10138415308e-07,
2.43503982686e-07, 2.72107045549e-07, 3.02911771395e-07,
3.26499455951e-07, 3.48319349445e-07, 5.13187669283e-07,
5.98480176303e-07, 6.57028222701e-07, 6.98347073045e-07,
7.28699930335e-07, 7.50686502279e-07, 7.7015576866e-07,
7.87147246927e-07, 7.99607141001e-07, 8.61398763228e-07,
8.84272900407e-07, 8.96463883243e-07, 9.04105135329e-07,
9.08443443149e-07, 9.12391264185e-07, 9.150842683e-07,
9.16878548643e-07, 9.18389990067e-07])

trialX = numpy.linspace(xData[0],xData[-1],1000)

# Fit a polynomial
fitted = numpy.polyfit(xData, yData, 10)[::-1]
y = numpy.zeros(len(trialX))
for i in range(len(fitted)):
   y += fitted[i]*trialX**i

# Fit an exponential
popt, pcov = curve_fit(func, xData, yData,
    diag=(1./xData.mean(),1./yData.mean()) )
yEXP = func(trialX, *popt)

pyplot.figure()
pyplot.plot(xData, yData, label='Data', marker='o')
pyplot.plot(trialX, yEXP, 'r-',ls='--', label="Exp Fit")
pyplot.plot(trialX,   y, label = '10 Deg Poly')
pyplot.legend()
pyplot.show()

【讨论】:

对答案的补充非常好!在进行交互式分析时,先验知识几乎总是可用的,但自动设置并非总是如此。【参考方案3】:

数值算法在不输入极小(或大)数字时往往效果更好。

在这种情况下,图表显示您的数据具有极小的 x 和 y 值。如果对它们进行缩放,则拟合效果会更好:

xData = np.load('xData.npy')*10**5
yData = np.load('yData.npy')*10**5

from __future__ import division

import os
os.chdir(os.path.expanduser('~/tmp'))

import numpy as np
import scipy.optimize as optimize
import matplotlib.pyplot as plt

def func(x,a,b,c):
   return a*np.exp(-b*x)-c


xData = np.load('xData.npy')*10**5
yData = np.load('yData.npy')*10**5

print(xData.min(), xData.max())
print(yData.min(), yData.max())

trialX = np.linspace(xData[0], xData[-1], 1000)

# Fit a polynomial 
fitted = np.polyfit(xData, yData, 10)[::-1]
y = np.zeros(len(trialX))
for i in range(len(fitted)):
   y += fitted[i]*trialX**i

# Fit an exponential
popt, pcov = optimize.curve_fit(func, xData, yData)
print(popt)
yEXP = func(trialX, *popt)

plt.figure()
plt.plot(xData, yData, label='Data', marker='o')
plt.plot(trialX, yEXP, 'r-',ls='--', label="Exp Fit")
plt.plot(trialX, y, label = '10 Deg Poly')
plt.legend()
plt.show()

注意xDatayData重新缩放后,curve_fit返回的参数也必须重新缩放。在这种情况下,abc 必须分别除以 10**5 才能获得原始数据的拟合参数。


您可能对上述内容的一个反对意见是,必须“谨慎”地选择缩放比例。 (阅读:并非所有合理的比例选择都有效!)

您可以通过为参数提供合理的初始猜测来提高curve_fit 的稳健性。通常您对数据有一些先验知识,这些知识可以激发对合理参数值的粗略/信封类型猜测。

例如,用

调用curve_fit
guess = (-1, 0.1, 0)
popt, pcov = optimize.curve_fit(func, xData, yData, guess)

有助于提高curve_fit 在这种情况下成功的范围。

【讨论】:

这样好多了!它不喜欢小数字有什么原因吗? 我对@9​​87654335@ 算法的研究还不够仔细,无法告诉你确切的原因。但总的来说,这些算法需要测试参数值的猜测,然后调整猜测。如果数据的幅度在 1 左右,初始调整的大小可能会很好,但如果数据的幅度在 10**-6 左右,则可能完全超出正确答案。 @unutbu 最初的猜测大约是 1 是对的。来自 docs.scipy.org/doc/scipy/reference/generated/… p0 : None, scalar, or M-length sequence Initial guess for the parameters. If None, then the initial values will all be 1 (if the number of parameters for the function can be determined using introspection, otherwise a ValueError is raised). 哪里 scipy.optimize.curve_fit(f, xdata, ydata, p0=None, sigma=None, **kw)[source]

以上是关于为啥 scipy.optimize.curve_fit 不适合数据?的主要内容,如果未能解决你的问题,请参考以下文章

为啥 DataGridView 上的 DoubleBuffered 属性默认为 false,为啥它受到保护?

为啥需要softmax函数?为啥不简单归一化?

为啥 g++ 需要 libstdc++.a?为啥不是默认值?

为啥或为啥不在 C++ 中使用 memset? [关闭]

为啥临时变量需要更改数组元素以及为啥需要在最后取消设置?

为啥 CAP 定理中的 RDBMS 分区不能容忍,为啥它可用?