向量化嵌套 vmap

Posted

技术标签:

【中文标题】向量化嵌套 vmap【英文标题】:Vectorise nested vmap 【发布时间】:2021-12-19 02:42:21 【问题描述】:

这是我掌握的一些数据:

import jax.numpy as jnp
import numpyro.distributions as dist
import jax

xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-3, 3, 5)

我想运行函数

def func(x, y):
    return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))

在来自xaxisyaxis 的每一对值上。

这是一种“缓慢”的做法:

results = np.zeros((len(xaxis), len(yaxis)))

for i in range(len(xaxis)):
    for j in range(len(yaxis)):
        results[i, j] = func(xaxis[i], yaxis[j])

有效,但速度很慢。

所以这是一种矢量化的方法:

jax.vmap(lambda axis: jax.vmap(func, (None, 0))(axis, yaxis))(xaxis)

快得多,但很难阅读。

有没有一种简洁的方式来编写矢量化版本?我可以用一个vmap 来完成,而不必将一个嵌套在另一个中吗?

编辑

另一种方法是

jax.vmap(func)(xmesh.flatten(), ymesh.flatten()).reshape(len(xaxis), len(yaxis)).T

但还是很乱。

【问题讨论】:

【参考方案1】:

我相信Vectorization guidelnes for jax 与您的问题非常相似;使用 vmap 复制嵌套 for 循环的逻辑需要嵌套的 vmap。

使用jax.vmap 的最简洁的方法可能是这样的:

from functools import partial

@partial(jax.vmap, in_axes=(0, None))
@partial(jax.vmap, in_axes=(None, 0))
def func(x, y):
    return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))

func(xaxis, yaxis)

这里的另一个选择是使用jnp.vectorize API(通过多个 vmap 实现),在这种情况下,您可以执行以下操作:

print(jnp.vectorize(func)(xaxis[:, None], yaxis))

【讨论】:

以上是关于向量化嵌套 vmap的主要内容,如果未能解决你的问题,请参考以下文章

如何将 jax vmap 用于嵌套循环?

如何向量化嵌套循环

向量化嵌套循环,其中一个循环变量依赖于另一个

用于简单数组更新的 Jax vmap

向量化代码并从 pytorch 代码中删除嵌套循环

在python中为依赖于索引的函数向量化嵌套的for循环