不同长度的 JAX 批处理
Posted
技术标签:
【中文标题】不同长度的 JAX 批处理【英文标题】:JAX batching with different lengths 【发布时间】:2021-09-19 00:11:09 【问题描述】:我有一个函数compute(x)
,其中x
是jnp.ndarray
。现在,我想使用vmap
将其转换为一个函数,该函数接受一批数组x[i]
,然后jit
对其进行加速。 compute(x)
类似于:
def compute(x):
# ... some code
y = very_expensive_function(x)
return y
但是,每个数组x[i]
具有不同的长度。我可以通过用尾随零填充数组来轻松解决这个问题,这样它们都具有相同的长度 N
和 vmap(compute)
可以应用于形状为 (batch_size, N)
的批次。
但是,这样做会导致在每个数组 x[i]
的尾随零上也调用 very_expensive_function()
。有没有办法修改compute()
,使得very_expensive_function()
只在x
的一部分上调用,而不干扰vmap
和jit
?
【问题讨论】:
显而易见的解决方案是将每个 x[i] 的实际长度也传递给计算,然后对该 x[i] 进行切片,但这可能不受 jax 支持。看看这个:github.com/google/jax/issues/1007。也许传递一个面具是你可以做的。 this 回答有用吗? 【参考方案1】:使用 JAX,当您想要 jit 函数以加快速度时,给定的批处理参数 x
必须是定义良好的 ndarray(即 x[i] 必须具有相同的形状)。无论您是否使用vmap
,这都是正确的。
现在,通常的处理方法是填充这些数组。这意味着您在参数中添加掩码,以便填充值不会影响您的结果。例如,如果我想计算形状为(bath_size, max_length)
的填充值x
的softmax
,我需要“禁用”填充值的效果。这是一个例子:
import jax.numpy as jnp
import jax
PAD = 0
MINUS_INFINITY = -1e6
x = jnp.array([
[1, 2, 3, 4],
[1, 2, PAD, PAD],
[1, 2, 3, PAD]
])
mask = jnp.array([
[1, 1, 1, 1],
[1, 1, 0, 0],
[1, 1, 1, 0]
])
masked_sofmax = jax.nn.softmax(x + (1-mask)*MINUS_INFINITY)
它不像填充x
那么简单。您需要在每一步实际更改计算以禁用填充效果。在 softmax 的情况下,您可以通过将填充值设置为接近负无穷来做到这一点。
最后,您无法真正提前知道使用或不使用 padding + masking 的速度性能是否会更好。根据我的经验,它通常会导致 CPU 的良好改进,以及 GPU 的非常大的改进。特别是,批处理大小的选择对性能有很大的影响,因为更高的batch_size
将在统计上导致更高的max_length
,因此对填充值执行更多的“无用”计算。
【讨论】:
以上是关于不同长度的 JAX 批处理的主要内容,如果未能解决你的问题,请参考以下文章
IBM JAX-RS 1.1 处理 Dojo Ajax OPTIONS 请求