为两个不规则网格之间的多个插值加速 scipy griddata

Posted

技术标签:

【中文标题】为两个不规则网格之间的多个插值加速 scipy griddata【英文标题】:Speedup scipy griddata for multiple interpolations between two irregular grids 【发布时间】:2014-01-21 19:44:28 【问题描述】:

我在同一个不规则网格(x, y, z) 上定义了几个值,我想将它们插值到新网格(x1, y1, z1) 上。即,我有f(x, y, z), g(x, y, z), h(x, y, z),我想计算f(x1, y1, z1), g(x1, y1, z1), h(x1, y1, z1)

目前我正在使用scipy.interpolate.griddata 执行此操作,并且效果很好。然而,因为我必须单独执行每个插值并且有很多点,所以它很慢,计算中存在大量重复(即找到最接近的点,设置网格等......)。

有没有办法加快计算速度,减少重复计算?即类似于定义两个网格,然后更改插值的值?

【问题讨论】:

您使用的是什么插值方法,即nearestlinear...?另外,你的不规则网格中有多少点? 我正在使用线性插值(最近的还不够好)。原始网格 (x,y,z) 由 350 万个点组成。新网格 (x1,y1,z1) 由大约 300,000 个点组成。在配备 i7 处理器和大量 RAM 的笔记本电脑上,线性插值需要大约 30 秒。我有 6 组值要插值,所以这对我来说是一个主要瓶颈。 【参考方案1】:

每次拨打scipy.interpolate.griddata时都会发生几件事情:

    首先,调用sp.spatial.qhull.Delaunay 对不规则网格坐标进行三角测量。 然后,对于新网格中的每个点,搜索三角剖分以找到它位于哪个三角形中(实际上,在哪个单纯形中,在您的 3D 情况下是哪个四面体)。 计算每个新网格点相对于封闭单纯形顶点的重心坐标。 使用重心坐标和封闭单纯形顶点处的函数值计算该网格点的插值。

前三个步骤对于所有插值都是相同的,因此如果您可以为每个新网格点存储封闭单纯形的顶点索引和插值的权重,则可以通过以下方式最小化计算量很多。不幸的是,这并不容易直接使用可用的功能来完成,尽管这确实是可能的:

import scipy.interpolate as spint
import scipy.spatial.qhull as qhull
import itertools

def interp_weights(xyz, uvw):
    tri = qhull.Delaunay(xyz)
    simplex = tri.find_simplex(uvw)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uvw - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

def interpolate(values, vtx, wts):
    return np.einsum('nj,nj->n', np.take(values, vtx), wts)

函数interp_weights 执行我上面列出的前三个步骤的计算。然后函数 interpolate 使用这些计算值非常快地执行第 4 步:

m, n, d = 3.5e4, 3e3, 3
# make sure no new grid point is extrapolated
bounding_cube = np.array(list(itertools.product([0, 1], repeat=d)))
xyz = np.vstack((bounding_cube,
                 np.random.rand(m - len(bounding_cube), d)))
f = np.random.rand(m)
g = np.random.rand(m)
uvw = np.random.rand(n, d)

In [2]: vtx, wts = interp_weights(xyz, uvw)

In [3]: np.allclose(interpolate(f, vtx, wts), spint.griddata(xyz, f, uvw))
Out[3]: True

In [4]: %timeit spint.griddata(xyz, f, uvw)
1 loops, best of 3: 2.81 s per loop

In [5]: %timeit interp_weights(xyz, uvw)
1 loops, best of 3: 2.79 s per loop

In [6]: %timeit interpolate(f, vtx, wts)
10000 loops, best of 3: 66.4 us per loop

In [7]: %timeit interpolate(g, vtx, wts)
10000 loops, best of 3: 67 us per loop

首先,它与griddata 的作用相同,这很好。其次,设置插值,即计算vtxwts 与调用griddata 大致相同。但第三,您现在几乎可以立即在同一网格上插入不同的值。

griddata 唯一没有在这里考虑的是将fill_value 分配给必须外推的点。您可以通过检查至少有一个权重为负的点来做到这一点,例如:

def interpolate(values, vtx, wts, fill_value=np.nan):
    ret = np.einsum('nj,nj->n', np.take(values, vtx), wts)
    ret[np.any(wts < 0, axis=1)] = fill_value
    return ret

【讨论】:

完美,正是我所追求的!非常感谢。如果在 scipy 中包含此类功能以用于 griddata 的未来版本,那就太好了。 对我来说效果很好!在我的机器上运行多次时,它使用的内存也比 scipy.itnerpolate.griddata 少得多。 另外,griddata 可容纳函数中的缺失值/漏洞 - nan,这不适用于此解决方案? @Jaime 如果我想用额外的点来更新数据,我可以使用tri = qhull.Delaunay(xy, incremental=True) 和更改tri.add_points(xy2) 来加快delaunay 部分,您对如何加快有任何想法find_simplex 只覆盖更新的索引? 如何使用三次插值(griddata 只是一个关键字)?【参考方案2】:

非常感谢 Jaime 的解决方案(即使我不太了解重心计算是如何完成的......)

在这里,您会找到一个改编自他的 2D 案例的示例:

import scipy.interpolate as spint
import scipy.spatial.qhull as qhull
import numpy as np

def interp_weights(xy, uv,d=2):
    tri = qhull.Delaunay(xy)
    simplex = tri.find_simplex(uv)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uv - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

def interpolate(values, vtx, wts):
    return np.einsum('nj,nj->n', np.take(values, vtx), wts)

m, n = 101,201
mi, ni = 1001,2001

[Y,X]=np.meshgrid(np.linspace(0,1,n),np.linspace(0,2,m))
[Yi,Xi]=np.meshgrid(np.linspace(0,1,ni),np.linspace(0,2,mi))

xy=np.zeros([X.shape[0]*X.shape[1],2])
xy[:,0]=Y.flatten()
xy[:,1]=X.flatten()
uv=np.zeros([Xi.shape[0]*Xi.shape[1],2])
uv[:,0]=Yi.flatten()
uv[:,1]=Xi.flatten()

values=np.cos(2*X)*np.cos(2*Y)

#Computed once and for all !
vtx, wts = interp_weights(xy, uv)
valuesi=interpolate(values.flatten(), vtx, wts)
valuesi=valuesi.reshape(Xi.shape[0],Xi.shape[1])
print "interpolation error: ",np.mean(valuesi-np.cos(2*Xi)*np.cos(2*Yi))  
print "interpolation uncertainty: ",np.std(valuesi-np.cos(2*Xi)*np.cos(2*Yi))  

可以应用图像变换,例如图像映射,并加快 udge 速度

您不能使用相同的函数定义,因为新坐标会在每次迭代中发生变化,但您可以一次性计算三角测量。

import scipy.interpolate as spint
import scipy.spatial.qhull as qhull
import numpy as np
import time

# Definition of the fast  interpolation process. May be the Tirangulation process can be removed !!
def interp_tri(xy):
    tri = qhull.Delaunay(xy)
    return tri


def interpolate(values, tri,uv,d=2):
    simplex = tri.find_simplex(uv)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uv- temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)  
    return np.einsum('nj,nj->n', np.take(values, vertices),  np.hstack((bary, 1.0 - bary.sum(axis=1, keepdims=True))))

m, n = 101,201
mi, ni = 101,201

[Y,X]=np.meshgrid(np.linspace(0,1,n),np.linspace(0,2,m))
[Yi,Xi]=np.meshgrid(np.linspace(0,1,ni),np.linspace(0,2,mi))

xy=np.zeros([X.shape[0]*X.shape[1],2])
xy[:,1]=Y.flatten()
xy[:,0]=X.flatten()
uv=np.zeros([Xi.shape[0]*Xi.shape[1],2])
# creation of a displacement field
uv[:,1]=0.5*Yi.flatten()+0.4
uv[:,0]=1.5*Xi.flatten()-0.7
values=np.zeros_like(X)
values[50:70,90:150]=100.

#Computed once and for all !
tri = interp_tri(xy)
t0=time.time()
for i in range(0,100):
  values_interp_Qhull=interpolate(values.flatten(),tri,uv,2).reshape(Xi.shape[0],Xi.shape[1])
t_q=(time.time()-t0)/100

t0=time.time()
values_interp_griddata=spint.griddata(xy,values.flatten(),uv,fill_value=0).reshape(values.shape[0],values.shape[1])
t_g=time.time()-t0

print "Speed-up:", t_g/t_q
print "Mean error: ",(values_interp_Qhull-values_interp_griddata).mean()
print "Standard deviation: ",(values_interp_Qhull-values_interp_griddata).std()

在我的笔记本电脑上,加速在 20 到 40 倍之间!

希望能帮助到别人

【讨论】:

interp_weights 函数在这里失败,delta = uv - temp[:, d],因为d 超出了temp 的范围【参考方案3】:

我遇到了同样的问题(griddata 非常慢,网格对于许多插值都保持不变),我最喜欢described here 的解决方案,主要是因为它非常易于理解和应用。

它使用LinearNDInterpolator,可以通过只需要计算一次的Delaunay三角剖分。从该帖子复制并粘贴(所有学分归 xdze2):

from scipy.spatial import Delaunay
from scipy.interpolate import LinearNDInterpolator

tri = Delaunay(mesh1)  # Compute the triangulation

# Perform the interpolation with the given values:
interpolator = LinearNDInterpolator(tri, values_mesh1)
values_mesh2 = interpolator(mesh2)

这将我的计算速度提高了大约 2 倍。

【讨论】:

【参考方案4】:

您可以尝试使用Pandas,因为它提供了高性能的数据结构。

确实,插值方法是 scipy 插值的包装器,但也许通过改进的结构,您可以获得更好的速度。

import pandas as pd;
wp = pd.Panel(randn(2, 5, 4));
wp.interpolate();

interpolate() 使用different methods 填充 Panel 数据集中的 NaN 值。希望它比 Scipy 更快。

如果它不起作用,有一种方法可以提高性能(而不是使用代码的并行版本):使用 Cython 并在 C 中实现小例程以在内部使用你的 Python 代码。 Here 你有一个关于这个的例子。

【讨论】:

以上是关于为两个不规则网格之间的多个插值加速 scipy griddata的主要内容,如果未能解决你的问题,请参考以下文章

在Python / Numpy / Scipy中找到两个数组之间的插值交集

python对不规则(x,y,z)网格进行4D插值

当在两个不同的vbo之间进行插值时,网格面向迷失方向

适当的 numpy/scipy 函数来插入在单纯形(非规则网格)上定义的函数

不能通过scipy.interpolate.griddata对n维网格进行插值

不规则网格上的插值