使用vmap(jax)对矩阵元素求和?

Posted

技术标签:

【中文标题】使用vmap(jax)对矩阵元素求和?【英文标题】:sum matrix elementwise using vmap (jax)? 【发布时间】:2021-09-20 18:25:20 【问题描述】:

我正在尝试了解 vmap 中的 in_axes 和 out_axes 选项。 例如,我想对两个矩阵求和并得到相同形状的输出。

X = np.arange(9).reshape(3,3)
Y = np.arange(0,-9,-1).reshape(3,3)
def sum2(x,y):
    return x + y
vmap(sum2,in_axes=((0,1),(0,1)))(X,Y)

我想我分别为 X 和 Y 映射了轴 0 和 1。输出将具有与 X,Y 相同的形状。 但我得到了错误,

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-403-103694166574> in <module>
      3 def sum2(x,y):
      4     return x + y
----> 5 vmap(sum2,in_axes=((0,1),(0,1)))(X,Y)

    [... skipping hidden 2 frame]

~/anaconda3/lib/python3.8/site-packages/jax/api_util.py in flatten_axes(name, treedef, axis_tree, kws)
    276       assert treedef_is_leaf(leaf)
    277       axis_tree, _ = axis_tree
--> 278     raise ValueError(f"name specification must be a tree prefix of the "
    279                      f"corresponding value, got specification axis_tree "
    280                      f"for value tree treedef.") from None

ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification ((0, 1), (0, 1)) for value tree PyTreeDef((*, *)).

【问题讨论】:

【参考方案1】:

首先,按元素求和最简单的方法是使用二进制运算的内置广播,并直接调用sum2(X, Y)

也就是说,如果您想了解vmap:问题是vmap 一次只能映射一个轴。如果要映射多个轴,可以嵌套多个 vmap。我相信你的意图可以这样表达:

from jax import vmap
import jax.numpy as np

X = np.arange(9).reshape(3,3)
Y = np.arange(0,-9,-1).reshape(3,3)

def sum2(x,y):
    assert x.ndim == y.ndim == 0
    return x + y

vmap(vmap(sum
  vmap(sum2, in_axes=(0, 0), out_axes=0),
  in_axes=(1, 1), out_axes=1
)(X,Y)

注意:我添加了关于维数的断言,以证明映射函数是在标量值上调用的。

另外,请注意,当映射轴匹配时,例如in_axes=(0, 0) 可以等效地写成 in_axes=0,但我将其保留为元组,因为它更接近您尝试的语法。

事实上,使用嵌套 vmap 进行相同计算的更简洁的方法是使用默认参数:vmap(vmap(sum2))(X, Y) 将执行相同的元素求和。

【讨论】:

谢谢!我映射了第一个轴并将输出重新整形。

以上是关于使用vmap(jax)对矩阵元素求和?的主要内容,如果未能解决你的问题,请参考以下文章

Clickhouse - 矩阵逐项加法:如何对二维数组求和?

用于简单数组更新的 Jax vmap

python数组求和

如何将 jax vmap 用于嵌套循环?

使用 vmap 时,Jax 不支持不可散列的静态参数

Jax 矢量化:vmap 和/或 numpy.vectorize?