是否可以 jit 使用 jax.numpy.unique 的函数?

Posted

技术标签:

【中文标题】是否可以 jit 使用 jax.numpy.unique 的函数?【英文标题】:is it possible to jit a function which uses jax.numpy.unique? 【发布时间】:2021-08-16 18:04:52 【问题描述】:

以下代码不起作用:

def get_unique(arr):
    return jnp.unique(arr)

get_unique = jit(get_unique)
get_unique(jnp.ones((10,)))

错误信息涉及jnp.unique的使用:

FilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=0/1)>
The error arose in jnp.unique()

documentation on sharp bits 解释说,如果内部数组的形状取决于参数值,则 jit 不起作用。这正是这里的情况。

根据文档,一个潜在的解决方法是指定静态参数。但这不适用于我的情况。几乎每个函数调用的参数都会改变。我已将我的代码拆分为一个预处理步骤,该步骤执行诸如 jnp.unique 之类的计算,以及一个可以 jitted 的计算步骤。

但我还是想问一下,是否有一些我不知道的解决方法?

【问题讨论】:

【参考方案1】:

不,由于您提到的原因,目前无法在非静态值上使用 jnp.unique

在类似的情况下,JAX 有时会添加额外的参数,这些参数可用于指定输出的静态大小(例如,jax.numpy.nonzero 中的size 参数),但目前jnp.unique 没有实现类似的功能。如果这是您想要的,值得提交feature request。

【讨论】:

以上是关于是否可以 jit 使用 jax.numpy.unique 的函数?的主要内容,如果未能解决你的问题,请参考以下文章

如何检测是不是启用了 PHP JIT

Lambda初次使用很慢?从JIT到类加载再到实现原理

有没有办法关闭 JIT 编译器,这样做会影响性能吗?

对 Linux 内核中的 eBPF JIT 漏洞进行 fuzz

你需要了解的JIT Debugging

Java JIT 是不是曾经优化递归方法调用?