JIT Jax 中的最小二乘损失函数
Posted
技术标签:
【中文标题】JIT Jax 中的最小二乘损失函数【英文标题】:JIT a least squares loss function in Jax 【发布时间】:2021-12-04 04:54:05 【问题描述】:我有一个看起来像这样的简单损失函数
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
我想优化参数r
并使用一些静态参数x
和y
来计算残差。所有有问题的参数都是DeviceArrays
。
为了 JIT 这个,我尝试了以下操作
@partial(jax.jit, static_argnums=(1, 2))
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
但我收到此错误
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'numpy.ndarray'> for function loss is non-hashable.
我从#6233 了解到这是设计使然,但我想知道这里的解决方法是什么,因为这似乎是一个非常常见的用例,您有一些固定的(输入、输出)训练数据对和一些免费的变量。
感谢任何提示!
编辑:这是我尝试使用 jax.jit
时遇到的错误
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function loss at /path/to/my/script:9 for jit, this concrete value was not available in Python because it depends on the value of the argument 'r'.`
【问题讨论】:
【参考方案1】:听起来您将静态参数视为“在计算之间不变的值”。在 JAX 的 JIT 中,静态参数可以更好地被认为是“可散列的编译时常量”。在您的情况下,您没有可散列的编译时常量;你有数组,所以你可以在没有静态参数的情况下进行 JIT 编译:
@jit
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
如果您真的希望 JAX 机器知道您的数组是常量,您可以通过闭包或部分传递它们来做到这一点;例如:
from functools import partial
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
loss = jit(partial(loss, x=x, y=y))
但是,对于您正在执行的计算类型,其中常量是由 JAX 数组函数操作的数组,这两种方法导致基本相同的降低 XLA 代码,因此您不妨使用更简单的一种。
【讨论】:
感谢您的详细回答。但是,当我尝试仅使用@jit
运行它时出现错误。 (在 OP 中以获得更好的可读性)。在堆栈跟踪之后,f
似乎有问题,但我想知道是否可以了解更多关于这种错误意味着什么的背景信息。按照它提供的链接,我会看到静态参数注释。
其实我会单独发一篇文章,因为它与 OP 无关。感谢您的帮助!以上是关于JIT Jax 中的最小二乘损失函数的主要内容,如果未能解决你的问题,请参考以下文章
LSSVM回归预测基于matlab人工蜂群算法优化最小二乘支持向量机LSSVM数据回归预测含Matlab源码 2213期
LSSVM回归预测基于matlab天鹰算法优化最小二乘支持向量机AO-LSSVM数据回归预测含Matlab源码 1848期
LSSVM回归预测基于matlab天鹰算法优化最小二乘支持向量机AO-LSSVM数据回归预测含Matlab源码 1848期