使用 for 循环时如何减少 JAX 编译时间?
Posted
技术标签:
【中文标题】使用 for 循环时如何减少 JAX 编译时间?【英文标题】:How to reduce JAX compile time when using for loop? 【发布时间】:2021-11-03 07:45:08 【问题描述】:这是一个基本示例。
@jax.jit
def block(arg1, arg2):
for x1 in range(cons1):
for x2 in range(cons2):
for x3 in range(cons3):
--do something--
return result
当 cons 较小时,编译时间约为一分钟。使用更大的缺点,编译时间要长得多——10 分钟。我需要更高的缺点。可以做什么? 从我正在阅读的内容来看,循环是原因。它们在编译时展开。 有什么解决方法吗?还有 jax.fori_loop。但我不明白如何使用它。有 jax.experimental.loops 模块,但我还是无法理解它。
我对这一切都很陌生。因此,感谢所有帮助。 如果您能提供一些如何使用 jax 循环的示例,将不胜感激。
另外,什么是好的编译时间?几分钟内就可以了吗? 在其中一个示例中,编译时间为 262 秒,剩余运行时间约为 0.1-0.2 秒。
运行时的任何收益都会被编译时间所掩盖。
【问题讨论】:
【参考方案1】:我不确定这是否与numba
相同,但这可能是类似的情况。
当我使用numba.jit
编译器并且有大数据输入时,我首先在一些小的示例数据上编译函数,然后使用它。
伪代码:
func_being_compiled(small_amount_of_data) # compile-only purpose
func_being_compiled(large_amount_of_data)
【讨论】:
【参考方案2】:JAX 的 JIT 编译器会扁平化所有 Python 循环。要了解我的意思,请看一下通过jax.make_jaxpr
运行的这个简单函数,它是一种检查 JAX 的跟踪器如何解释 python 代码的方法(请参阅Understanding Jaxprs 了解更多信息):
import jax
def f(x):
for i in range(5):
x += i
return x
print(jax.make_jaxpr(f)(0))
# lambda ; a.
# let b = add a 0
# c = add b 1
# d = add c 2
# e = add d 3
# f = add e 4
# in (f,)
请注意,循环是扁平的:每一步都成为发送到 XLA 编译器的显式操作。 XLA 编译时间会随着函数中操作数量的增加而增加,因此三重嵌套的 for 循环会导致编译时间变长是有道理的。
那么,如何解决这个问题?好吧,不幸的是,答案取决于你的 --do something--
正在做什么,所以我猜不出来。
一般来说,最好的选择是使用向量化数组操作,而不是循环这些向量中的值;例如,这是一种添加两个向量的非常慢的方法:
import jax.numpy as jnp
def f_slow(x, y):
z = []
for xi, yi in zip(xi, yi):
z.append(xi + yi)
return jnp.array(z)
这是一种更快的方法来做同样的事情:
def f_fast(x, y):
return x + y
如果您的操作不适合矢量化,另一种选择是使用 lax control flow 运算符代替 for
循环:这会将循环向下推入 XLA。这在 CPU 上可以有相当好的性能,但与等效的向量化数组操作相比,在加速器上的速度较慢。
有关 JAX 和 Python 控制流语句(如for
、if
、while
等)的更多讨论,请参阅? JAX - The Sharp Bits ?: Control Flow。
【讨论】:
对于无法向量化的操作,jax.lax.fori_loop
与 python for
循环相比显着减少了编译时间。而且,确实,它不需要减少计算时间。以上是关于使用 for 循环时如何减少 JAX 编译时间?的主要内容,如果未能解决你的问题,请参考以下文章