用于简单数组更新的 Jax vmap

Posted

技术标签:

【中文标题】用于简单数组更新的 Jax vmap【英文标题】:Jax vmap for simple array update 【发布时间】:2021-08-11 18:53:24 【问题描述】:

我是 Jax 的新手,我正在努力转换其他人的代码,该代码使用 numba “fastmath” 功能并依赖于许多嵌套的 for 循环,而不会造成太大的性能损失。我正在尝试使用 Jax 的 vmap 函数重新创建相同的行为。但是,我目前在一些基本问题上遇到了很多困难。这是我尝试使用 vmap 进行矢量化的简化示例:

import jax.numpy as jnp
from jax import vmap
import jax.ops

a = jnp.arange(20).reshape((4, 5))
b = jnp.arange(5)
c = jnp.arange(4)
d = jnp.zeros(20)
e = jnp.zeros((4, 5))

for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        a = jax.ops.index_add(a, jax.ops.index[i, j], b[j] + c[i])
        d = jax.ops.index_update(d, jax.ops.index[i*a.shape[1] + j], b[j] * c[i])
        e = jax.ops.index_update(e, jax.ops.index[i, j], 2*b[j])

如何使用 vmap 重写这样的代码?虽然此代码相对容易手动矢量化,但我希望更好地了解 vmap 的工作原理,并希望任何答案都能帮助我。这些文档现在似乎并没有真正帮助我。非常感谢您提供的任何帮助。

【问题讨论】:

【参考方案1】:

以下是使用vmap 实现大致相同计算的方法:

from jax import vmap, partial

@partial(vmap, in_axes=(0, None, 0))
@partial(vmap, in_axes=(0, 0, None))
def f(a, b, c):
  return a + b + c, b * c, 2 * b

a, d, e = f(a, b, c)
d = d.ravel()

【讨论】:

非常感谢!这个例子真的帮助了我学习。史诗。

以上是关于用于简单数组更新的 Jax vmap的主要内容,如果未能解决你的问题,请参考以下文章

Jax 矢量化:vmap 和/或 numpy.vectorize?

使用vmap(jax)对矩阵元素求和?

使用 vmap 时,Jax 不支持不可散列的静态参数

不同长度的 JAX 批处理

webService的简单使用

JAX-RS 不适用于 Spring Boot 1.4.1