Jax 矢量化:vmap 和/或 numpy.vectorize?

Posted

技术标签:

【中文标题】Jax 矢量化:vmap 和/或 numpy.vectorize?【英文标题】:Jax vectorization: vmap and/or numpy.vectorize? 【发布时间】:2021-11-05 00:53:48 【问题描述】:

jax.numpy.vectorizejax.vmap有什么区别? 这是一个小片段

import jax
import jax.numpy as jnp

def f(x):
     return jnp.exp(-x)*jnp.sin(x)

gf = jax.grad(f)
x = jnp.arange(0,1,0.1)

jax.vmap(gf)(x)
jnp.vectorize(gf)(x)

两种计算都给出相同的结果:

DeviceArray([ 1. , 0.80998397, 0.63975394, 0.4888039 , 0.35637075, 0.24149445, 0.14307144, 0.05990037, -0.00927836, -0.06574923], dtype=float32)

如何决定使用哪一个,以及在性能方面是否存在差异?

【问题讨论】:

【参考方案1】:

jax.vmapjax.numpy.vectorize 具有完全不同的语义,并且仅在您的示例中的单个 1D 输入的情况下恰好相似。

jax.vmap 的目的是将函数映射到沿单个显式轴的一个或多个输入上,如 in_axes 参数所指定。另一方面,jax.numpy.vectorize 根据 numpy 广播规则沿零个或多个隐式轴将函数映射到一个或多个输入。

要查看差异,让我们传递两个二维输入并在函数内打印形状:

import jax
import jax.numpy as jnp

def print_shape(x, y):
  print(f"x.shape = x.shape")
  print(f"y.shape = y.shape")
  return x + y

x = jnp.zeros((20, 10))
y = jnp.zeros((20, 10))

_ = jax.vmap(print_shape)(x, y)
# x.shape = (10,)
# y.shape = (10,)

_ = jnp.vectorize(print_shape)(x, y)
# x.shape = ()
# y.shape = ()

请注意,vmap 仅沿第一个轴映射,而 vectorize 沿两个输入轴映射。

还要注意vectorize 的隐式映射意味着它可以更灵活地使用;例如:

x2 = jnp.arange(10)
y2 = jnp.arange(20).reshape(20, 1)

def add(x, y):
  # vectorize always maps over all axes, such that the function is applied elementwise
  assert x.shape == y.shape == ()
  return x + y

jnp.vectorize(add)(x2, y2).shape
# (20, 10)

vectorize 将根据 numpy 广播规则遍历输入的所有轴。另一方面,vmap 默认无法处理此问题:

jax.vmap(add)(x2, y2)
# ValueError: vmap got inconsistent sizes for array axes to be mapped:
# arg 0 has shape (10,) and axis 0 is to be mapped
# arg 1 has shape (20, 1) and axis 0 is to be mapped
# so
# arg 0 has an axis to be mapped of size 10
# arg 1 has an axis to be mapped of size 20

vmap 完成同样的操作需要更多的思考,因为有两个独立的映射轴,并且一些轴是广播的。但是你可以通过这种方式完成同样的事情:

jax.vmap(jax.vmap(add, in_axes=(None, 0)), in_axes=(0, None))(x2, y2[:, 0]).shape
# (20, 10)

后者嵌套的vmap 本质上是当您使用jax.numpy.vectorize 时发生的事情。

至于在任何特定情况下使用哪个:

如果您想将函数映射到单个、明确指定的输入轴,请使用jax.vmap 如果您希望根据应用于输入的 numpy 广播规则将函数的输入映射到零个或多个轴上,请使用 jax.numpy.vectorize。 在变换相同的情况下(例如在映射一维输入的函数时)倾向于使用vmap,因为它更直接地执行您想做的事情。

【讨论】:

谢谢@jakevdp 你的意思是对于 nD 输入 (n>1) 在某些情况下矢量化更好吗?还是应该使用 vmap? 不是更好,而是不同。如果您希望根据 numpy 广播规则广播您的输入,请使用矢量化。如果您希望您的函数映射到输入的单个特定轴,请使用 vmap。 我编辑了答案以添加更多上下文 - 希望对您有所帮助! 喂!非常感谢@jakevdp,您丰富的答案非常有启发性。

以上是关于Jax 矢量化:vmap 和/或 numpy.vectorize?的主要内容,如果未能解决你的问题,请参考以下文章

如何将 jax vmap 用于嵌套循环?

使用vmap(jax)对矩阵元素求和?

使用 vmap 时,Jax 不支持不可散列的静态参数

向量化嵌套 vmap

不同长度的 JAX 批处理

DS.DELMIA.VMAP.V5-6R2017.SP2.Win32