不同长度的 JAX 批处理

Posted

技术标签:

【中文标题】不同长度的 JAX 批处理【英文标题】:JAX batching with different lengths 【发布时间】:2021-09-19 00:11:09 【问题描述】:

我有一个函数compute(x),其中xjnp.ndarray。现在,我想使用vmap 将其转换为一个函数,该函数接受一批数组x[i],然后jit 对其进行加速。 compute(x) 类似于:

def compute(x):
    # ... some code
    y = very_expensive_function(x)
    return y

但是,每个数组x[i] 具有不同的长度。我可以通过用尾随零填充数组来轻松解决这个问题,这样它们都具有相同的长度 Nvmap(compute) 可以应用于形状为 (batch_size, N) 的批次。

但是,这样做会导致在每个数组 x[i] 的尾随零上也调用 very_expensive_function()。有没有办法修改compute(),使得very_expensive_function() 只在x 的一部分上调用,而不干扰vmapjit

【问题讨论】:

显而易见的解决方案是将每个 x[i] 的实际长度也传递给计算,然后对该 x[i] 进行切片,但这可能不受 jax 支持。看看这个:github.com/google/jax/issues/1007。也许传递一个面具是你可以做的。 this 回答有用吗? 【参考方案1】:

使用 JAX,当您想要 jit 函数以加快速度时,给定的批处理参数 x 必须是定义良好的 ndarray(即 x[i] 必须具有相同的形状)。无论您是否使用vmap,这都是正确的。

现在,通常的处理方法是填充这些数组。这意味着您在参数中添加掩码,以便填充值不会影响您的结果。例如,如果我想计算形状为(bath_size, max_length) 的填充值xsoftmax,我需要“禁用”填充值的效果。这是一个例子:

import jax.numpy as jnp
import jax

PAD = 0
MINUS_INFINITY = -1e6

x = jnp.array([ 
       [1, 2, 3, 4],
       [1, 2, PAD, PAD],
       [1, 2, 3, PAD]
    ])

mask = jnp.array([
           [1, 1, 1, 1],
           [1, 1, 0, 0],
           [1, 1, 1, 0]
       ])
       
masked_sofmax = jax.nn.softmax(x + (1-mask)*MINUS_INFINITY)    

它不像填充x那么简单。您需要在每一步实际更改计算以禁用填充效果。在 softmax 的情况下,您可以通过将填充值设置为接近负无穷来做到这一点。

最后,您无法真正提前知道使用或不使用 padding + masking 的速度性能是否会更好。根据我的经验,它通常会导致 CPU 的良好改进,以及 GPU 的非常大的改进。特别是,批处理大小的选择对性能有很大的影响,因为更高的batch_size 将在统计上导致更高的max_length,因此对填充值执行更多的“无用”计算。

【讨论】:

以上是关于不同长度的 JAX 批处理的主要内容,如果未能解决你的问题,请参考以下文章

如何使用 JAX-RS 和 Jersey 处理 CORS

如何使用JAX-RS和Jersey处理CORS

在 JAX-RS 2.0 客户端库中处理自定义错误响应

IBM JAX-RS 1.1 处理 Dojo Ajax OPTIONS 请求

如何使用批处理运行 RNN 模型,其中每行具有不同长度的文本?

JAX-RS入门 三 :细节