如何使用 CuPy 在 python 上进行三次样条插值?

Posted

技术标签:

【中文标题】如何使用 CuPy 在 python 上进行三次样条插值?【英文标题】:How to do a cubic spline interpolation on python with CuPy? 【发布时间】:2020-03-12 23:32:44 【问题描述】:

我正在编写代码,使用 GPU 多次进行三次样条插值。 我知道如何在 numpy 上使用

scipy.interpolate.splrep

scipy.interpolate.interp1d(kind='cubic')

interp1d 是我现在用于 numpy 数组的。但我需要在 CuPy 上运行它们。

但是我应该如何在 CuPy 上做到这一点?我有一个 x 值和 y 值。而且我还有一个带有新 x 值的数组。我现在正在编写的代码将为新的 x 值计算新的 y 值。

【问题讨论】:

【参考方案1】:

我先从纯python的源代码scipy.interpolate.CubicSpline开始,把np.something一一替换成cupy.something

【讨论】:

但是cupy是基于cuda的。 cuda中有三次样条插值。我只想在cupy中找到那个功能。但我不能。我试过谷歌它,但找不到任何东西。我可以做你的建议,但如果有一些简单的方法来处理它,我只是措辞。 我试过这个。可悲的是,它还具有 scipy 固有的许多不同功能,例如solve_banded。我放弃。如果我只想重写整个cubicSpline,这是一个很大的项目,因为我还需要重写scipy中的其他函数。 Solve_banded 具体很简单,主要是lapack调用,直接到cuda lapack 我认为我没有很好地解释它......solve_banded 只是一个例子......还有其他部分我需要改变,比如solve()。但是当我去解决和solve_banded的文件时,他们会导入更多的代码。然后我去他们导入的代码,更多导入的函数。然后我需要修改大量的代码行,这非常耗时。 抱歉再次打扰您。我正在尝试重新编写它。但这里有一个问题:CubicSpline 是 PPoly 的子类,对于 PPoly,它使用来自 cython 代码 _ppoly 的函数 _ppoly.evaluate。我是否应该简单地将代码 _ppoly 上的所有 numpy 更改为 cupy?【参考方案2】:

我现在可以自己回答这个问题了。我完成了插值算法。作为ev-br的回答的建议,我只是为cupy重写了scipy.interpolate.CubicSpline

scipy.interpolate.CubicSpline 包含很多函数,如果我们只需要一个插值函数,这将无济于事。

CubicSpline有一个父类scipy.interpolate.PPoly,如果你只想要一个插值函数,它也包含了很多不必要的函数。清理干净后,我只使用了 _PPolyBasesolve_banded()prepare_input() 类。

最难的部分是用 cython 编写的函数evaluate()。 Cython 不支持 cupy,所以我使用支持 cuda 的 numba 来加速循环的速度。

函数evaluate()的头部应该是这样的:

@cuda.jit('void(complex128[:,:,:], float64[:], float64[:], complex128[:,:])')
def evaluate(c, x, xp, out):

需要注意的重要一点是评估函数不是线程安全函数。

只有evaluate() 中的第一个循环是:

for ip in range(len(xp)):
    xval = xp[IP]
    ......

可以使用cuda.grid(1)cuda.gridsize(1)

另外,我将evaluate_poly1()find_interval_descending() 结合在evaluate() 中,以更好地适应numbe 的cuda 支持。

速度超级快,比原来的scipy函数快3到4倍左右。

代码可以在这里找到:https://github.com/GavinJiacheng/Interpolation_CUPY

【讨论】:

干得好!如果我理解正确,对splrep 执行相同操作似乎更复杂,因为splrep 中与evaluate 对应的函数是编译后的dfitpack.curfit。 (这会引发错误“将 dfitpack.curfit 的参数转换为 C/Fortran 数组失败).. .. 嗯... @niCkcAMel 抱歉回复晚了。但我认为你需要对许多不应该运行数学算法而是一些标志检查的函数做更多的清理和削减。他们中的许多人需要在 sippy 中调用一些检查函数,这是完全没用的。 不用担心。做得好!我克隆了你的 git 代码并试了一下。不幸的是,考虑到我的“通常”样本点数(~30000),它并没有带来任何改进(与未修改的 scipy-CubicSpline 相比)。我在大约 400 万个样本点上进行了权衡。

以上是关于如何使用 CuPy 在 python 上进行三次样条插值?的主要内容,如果未能解决你的问题,请参考以下文章

使用scipy.interpolate.CubicSpline添加或乘以三次样条曲线

C ++中的三次样条插值

三次样条插值介绍

三次样条插值法

三次样条插值

三次样条插值