使用 jax 数组索引到 numpy 数组:错误的错误消息

Posted

技术标签:

【中文标题】使用 jax 数组索引到 numpy 数组:错误的错误消息【英文标题】:indexing into numpy array with jax array: faulty error messages 【发布时间】:2021-08-16 20:17:43 【问题描述】:

以下 numpy 代码非常好:

arr = np.arange(50)
print(arr.shape) # (50,)

indices = np.zeros((30,), dtype=int)
print(indices.shape) # (30,)

arr[indices]

迁移到 jax 后也可以使用:

import jax.numpy as jnp

arr = jnp.arange(50)
print(arr.shape) # (50,)

indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)

arr[indices]

现在让我们尝试混合使用 numpy 和 jax:

arr = np.arange(50)
print(arr.shape) # (50,)

indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)

arr[indices]

这会产生以下错误:

IndexError: too many indices for array: array is 1-dimensional, but 30 were indexed

如果不支持使用 jax 数组对 numpy 数组进行索引,那对我来说很好。但是错误信息似乎是错误的。事情变得更加混乱。如果稍微改变形状,代码就可以正常工作。在下面的示例中,我只编辑了从 (30,) 到 (40,) 的索引形状。没有更多错误消息:

arr = np.arange(50)
print(arr.shape) # (50,)

indices = jnp.zeros((40,), dtype=int)
print(indices.shape) # (40,)

arr[indices]

我在 cpu 上运行 jax 版本“0.2.12”。 这里发生了什么?

【问题讨论】:

看起来像是将indices 视为一个元组 - 如果小于 32,则为最大维数。旧代码对一些列表进行了此操作,但较新的版本正在努力弃用该行为。 【参考方案1】:

这是一个长期存在的已知问题(请参阅https://github.com/google/jax/issues/620);这不是 JAX 本身可以轻松修复的错误,但需要更改 NumPy 处理非ndarray 索引的方式。好消息是修复即将到来:上面有问题的代码伴随着以下警告,该警告来自 NumPy:

FutureWarning: Using a non-tuple sequence for multidimensional indexing is
 deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this
 will be interpreted as an array index, `arr[np.array(seq)]`, which will result
 either in an error or a different result.

此弃用周期完成后,JAX 数组将在 NumPy 索引中正常工作。

在此之前,您可以在使用 JAX 数组索引 NumPy 数组时显式调用 np.asarray 来解决此问题。

【讨论】:

以上是关于使用 jax 数组索引到 numpy 数组:错误的错误消息的主要内容,如果未能解决你的问题,请参考以下文章

numpy多维索引:使用np数组并列出不同的结果

NumPy 数组切片索引

如何在 numpy 数组上定义一个使用数组索引查找字典的函数?

用jax计算行向(或轴向)点积的最佳方法是什么?

NumPy数组IndexError:索引99超出轴0的大小为1的范围

使用 numpy 数组有效地索引 numpy 数组