JIT Jax 中的最小二乘损失函数

Posted

技术标签:

【中文标题】JIT Jax 中的最小二乘损失函数【英文标题】:JIT a least squares loss function in Jax 【发布时间】:2021-12-04 04:54:05 【问题描述】:

我有一个看起来像这样的简单损失函数

        def loss(r, x, y):
            resid = f(r, x) - y
            return jnp.mean(jnp.square(resid))

我想优化参数r 并使用一些静态参数xy 来计算残差。所有有问题的参数都是DeviceArrays

为了 JIT 这个,我尝试了以下操作

        @partial(jax.jit, static_argnums=(1, 2))
        def loss(r, x, y):
            resid = f(r, x) - y
            return jnp.mean(jnp.square(resid))

但我收到此错误

jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'numpy.ndarray'> for function loss is non-hashable.

我从#6233 了解到这是设计使然,但我想知道这里的解决方法是什么,因为这似乎是一个非常常见的用例,您有一些固定的(输入、输出)训练数据对和一些免费的变量。

感谢任何提示!

编辑:这是我尝试使用 jax.jit 时遇到的错误

jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function loss at /path/to/my/script:9 for jit, this concrete value was not available in Python because it depends on the value of the argument 'r'.`

【问题讨论】:

【参考方案1】:

听起来您将静态参数视为“在计算之间不变的值”。在 JAX 的 JIT 中,静态参数可以更好地被认为是“可散列的编译时常量”。在您的情况下,您没有可散列的编译时常量;你有数组,所以你可以在没有静态参数的情况下进行 JIT 编译:

@jit
def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))

如果您真的希望 JAX 机器知道您的数组是常量,您可以通过闭包或部分传递它们来做到这一点;例如:

from functools import partial

def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))
loss = jit(partial(loss, x=x, y=y))

但是,对于您正在执行的计算类型,其中常量是由 JAX 数组函数操作的数组,这两种方法导致基本相同的降低 XLA 代码,因此您不妨使用更简单的一种。

【讨论】:

感谢您的详细回答。但是,当我尝试仅使用 @jit 运行它时出现错误。 (在 OP 中以获得更好的可读性)。在堆栈跟踪之后,f 似乎有问题,但我想知道是否可以了解更多关于这种错误意味着什么的背景信息。按照它提供的链接,我会看到静态参数注释。 其实我会单独发一篇文章,因为它与 OP 无关。感谢您的帮助!

以上是关于JIT Jax 中的最小二乘损失函数的主要内容,如果未能解决你的问题,请参考以下文章

LSSVM回归预测基于matlab人工蜂群算法优化最小二乘支持向量机LSSVM数据回归预测含Matlab源码 2213期

LSSVM回归预测基于matlab天鹰算法优化最小二乘支持向量机AO-LSSVM数据回归预测含Matlab源码 1848期

LSSVM回归预测基于matlab天鹰算法优化最小二乘支持向量机AO-LSSVM数据回归预测含Matlab源码 1848期

最小二乘回归,岭回归,Lasso回归,弹性网络

支持向量机之最小二乘(LS)-------6

犰狳函数的不同最小二乘误差