JAX:避免对沿一个轴使用不同数量的元素评估的函数进行即时重新编译

Posted

技术标签:

【中文标题】JAX:避免对沿一个轴使用不同数量的元素评估的函数进行即时重新编译【英文标题】:JAX: avoid just-in-time recompilation for a function evaluated with a varying number of elements along one axis 【发布时间】:2022-01-04 15:10:02 【问题描述】:

除了一个轴具有不同数量的元素之外,当输入的结构基本保持不变时,是否可以避免重新编译 JIT 函数?

import jax

@jax.jit
def f(x):
    print('recompiling')
    return (x + 10) * 100

a = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready()) # recompiling
b = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready())
c = f(jax.numpy.arange(450000000).reshape((-1, 2, 2)).block_until_ready()) # recompiling. It would be nice if it weren't

要求:pip install jax, jaxlib

【问题讨论】:

根据github.com/google/jax/issues/803,目前这似乎是不可能的。 XLA 编译器需要已知的形状。 【参考方案1】:

不,当您调用具有不同形状的数组的函数时,无法避免重新编译。从根本上说,JAX 为静态形状的输入和输出编译函数,并且使用新形状的数组调用 JIT 编译的函数将始终触发重新编译。

目前正在进行一些放宽此要求的工作(在 JAX 的 github 存储库中搜索“动态形状”),但目前没有此类 API 可用。

【讨论】:

以上是关于JAX:避免对沿一个轴使用不同数量的元素评估的函数进行即时重新编译的主要内容,如果未能解决你的问题,请参考以下文章

如何将 std::vector 的容量限制为元素的数量

评估具有不同长度的多个二进制测试答案的正确统计测试/ R 函数是啥?

如何使用 Diesel 计算数组列中不同元素的数量?

如何避免短路评估

Seaborn/Plotly 多个 y 轴

pytorch 可以优化顺序操作(如张量流图或 JAX 的 jit)吗?