如何将 jax vmap 用于嵌套循环?

Posted

技术标签:

【中文标题】如何将 jax vmap 用于嵌套循环?【英文标题】:How to use jax vmap for nested loops? 【发布时间】:2021-11-24 12:23:42 【问题描述】:

我想使用 vmap 对这段代码进行矢量化以提高性能。

def matrix(dataA, dataB):
    return jnp.array([[func(a, b) for b in dataB] for a in dataA])
matrix(data, data)

我试过这个:

def f(x, y):
    return func(x, y)
mapped = jax.vmap(f)
mapped(data, data)

但这只会给出对角线条目。

基本上我有一个向量data = [1,2,3,4,5](示例),我想得到一个矩阵,使得矩阵的每个条目(i, j)f(data[i], data[j])。因此,生成的矩阵形状将是(len(data), len(data))

【问题讨论】:

【参考方案1】:

jax.vmap 一次映射一组轴。如果要跨两组独立的轴进行映射,可以通过嵌套两个 vmap 转换来实现:

mapped = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))
result = mapped(data, data)

【讨论】:

以上是关于如何将 jax vmap 用于嵌套循环?的主要内容,如果未能解决你的问题,请参考以下文章

如何将两个for循环嵌套使用,要求内层循环结束,外层也一起结束。

丰富的数据表不能将迭代变量用于嵌套循环

有两个循环,他们是嵌套关系,在内循环中使用break将终止哪一个循环?

Tkinter 嵌套主循环

嵌套for循环慢python用于计算特殊标准偏差

batfor循环嵌套改名