使用 for 循环时如何减少 JAX 编译时间?

Posted

技术标签:

【中文标题】使用 for 循环时如何减少 JAX 编译时间?【英文标题】:How to reduce JAX compile time when using for loop? 【发布时间】:2021-11-03 07:45:08 【问题描述】:

这是一个基本示例。

@jax.jit
def block(arg1, arg2):
   for x1 in range(cons1):
       for x2 in range(cons2):
          for x3 in range(cons3):
             --do something--
   return result

当 cons 较小时,编译时间约为一分钟。使用更大的缺点,编译时间要长得多——10 分钟。我需要更高的缺点。可以做什么? 从我正在阅读的内容来看,循环是原因。它们在编译时展开。 有什么解决方法吗?还有 jax.fori_loop。但我不明白如何使用它。有 jax.experimental.loops 模块,但我还是无法理解它。

我对这一切都很陌生。因此,感谢所有帮助。 如果您能提供一些如何使用 jax 循环的示例,将不胜感激。

另外,什么是好的编译时间?几分钟内就可以了吗? 在其中一个示例中,编译时间为 262 秒,剩余运行时间约为 0.1-0.2 秒。

运行时的任何收益都会被编译时间所掩盖。

【问题讨论】:

【参考方案1】:

我不确定这是否与numba 相同,但这可能是类似的情况。

当我使用numba.jit编译器并且有大数据输入时,我首先在一些小的示例数据上编译函数,然后使用它。

伪代码:

func_being_compiled(small_amount_of_data)  # compile-only purpose
func_being_compiled(large_amount_of_data)

【讨论】:

【参考方案2】:

JAX 的 JIT 编译器会扁平化所有 Python 循环。要了解我的意思,请看一下通过jax.make_jaxpr 运行的这个简单函数,它是一种检查 JAX 的跟踪器如何解释 python 代码的方法(请参阅Understanding Jaxprs 了解更多信息):

import jax

def f(x):
  for i in range(5):
    x += i
  return x

print(jax.make_jaxpr(f)(0))
#  lambda  ; a.
#   let b = add a 0
#       c = add b 1
#       d = add c 2
#       e = add d 3
#       f = add e 4
#   in (f,) 

请注意,循环是扁平的:每一步都成为发送到 XLA 编译器的显式操作。 XLA 编译时间会随着函数中操作数量的增加而增加,因此三重嵌套的 for 循环会导致编译时间变长是有道理的。

那么,如何解决这个问题?好吧,不幸的是,答案取决于你的 --do something-- 正在做什么,所以我猜不出来。

一般来说,最好的选择是使用向量化数组操作,而不是循环这些向量中的值;例如,这是一种添加两个向量的非常慢的方法:

import jax.numpy as jnp

def f_slow(x, y):
  z = []
  for xi, yi in zip(xi, yi):
    z.append(xi + yi)
  return jnp.array(z)

这是一种更快的方法来做同样的事情:

def f_fast(x, y):
  return x + y

如果您的操作不适合矢量化,另一种选择是使用 lax control flow 运算符代替 for 循环:这会将循环向下推入 XLA。这在 CPU 上可以有相当好的性能,但与等效的向量化数组操作相比,在加速器上的速度较慢。

有关 JAX 和 Python 控制流语句(如forifwhile 等)的更多讨论,请参阅? JAX - The Sharp Bits ?: Control Flow。

【讨论】:

对于无法向量化的操作,jax.lax.fori_loop 与 python for 循环相比显着减少了编译时间。而且,确实,它不需要减少计算时间。

以上是关于使用 for 循环时如何减少 JAX 编译时间?的主要内容,如果未能解决你的问题,请参考以下文章

使用条件 while/for 循环的指针在编译时会出错

如何将 jax vmap 用于嵌套循环?

编译时内的模板参数展开for循环?

减少咖啡脚本中的循环

在 Python 中减少 for 循环是不可能的?

51单片机for循环(计算1加到10)问题?