Jax 矢量化:vmap 和/或 numpy.vectorize?
Posted
技术标签:
【中文标题】Jax 矢量化:vmap 和/或 numpy.vectorize?【英文标题】:Jax vectorization: vmap and/or numpy.vectorize? 【发布时间】:2021-11-05 00:53:48 【问题描述】:jax.numpy.vectorize
和jax.vmap
有什么区别?
这是一个小片段
import jax
import jax.numpy as jnp
def f(x):
return jnp.exp(-x)*jnp.sin(x)
gf = jax.grad(f)
x = jnp.arange(0,1,0.1)
jax.vmap(gf)(x)
jnp.vectorize(gf)(x)
两种计算都给出相同的结果:
DeviceArray([ 1. , 0.80998397, 0.63975394, 0.4888039 , 0.35637075, 0.24149445, 0.14307144, 0.05990037, -0.00927836, -0.06574923], dtype=float32)
如何决定使用哪一个,以及在性能方面是否存在差异?
【问题讨论】:
【参考方案1】:jax.vmap
和 jax.numpy.vectorize
具有完全不同的语义,并且仅在您的示例中的单个 1D 输入的情况下恰好相似。
jax.vmap
的目的是将函数映射到沿单个显式轴的一个或多个输入上,如 in_axes
参数所指定。另一方面,jax.numpy.vectorize
根据 numpy 广播规则沿零个或多个隐式轴将函数映射到一个或多个输入。
要查看差异,让我们传递两个二维输入并在函数内打印形状:
import jax
import jax.numpy as jnp
def print_shape(x, y):
print(f"x.shape = x.shape")
print(f"y.shape = y.shape")
return x + y
x = jnp.zeros((20, 10))
y = jnp.zeros((20, 10))
_ = jax.vmap(print_shape)(x, y)
# x.shape = (10,)
# y.shape = (10,)
_ = jnp.vectorize(print_shape)(x, y)
# x.shape = ()
# y.shape = ()
请注意,vmap
仅沿第一个轴映射,而 vectorize
沿两个输入轴映射。
还要注意vectorize
的隐式映射意味着它可以更灵活地使用;例如:
x2 = jnp.arange(10)
y2 = jnp.arange(20).reshape(20, 1)
def add(x, y):
# vectorize always maps over all axes, such that the function is applied elementwise
assert x.shape == y.shape == ()
return x + y
jnp.vectorize(add)(x2, y2).shape
# (20, 10)
vectorize
将根据 numpy 广播规则遍历输入的所有轴。另一方面,vmap
默认无法处理此问题:
jax.vmap(add)(x2, y2)
# ValueError: vmap got inconsistent sizes for array axes to be mapped:
# arg 0 has shape (10,) and axis 0 is to be mapped
# arg 1 has shape (20, 1) and axis 0 is to be mapped
# so
# arg 0 has an axis to be mapped of size 10
# arg 1 has an axis to be mapped of size 20
用vmap
完成同样的操作需要更多的思考,因为有两个独立的映射轴,并且一些轴是广播的。但是你可以通过这种方式完成同样的事情:
jax.vmap(jax.vmap(add, in_axes=(None, 0)), in_axes=(0, None))(x2, y2[:, 0]).shape
# (20, 10)
后者嵌套的vmap
本质上是当您使用jax.numpy.vectorize
时发生的事情。
至于在任何特定情况下使用哪个:
如果您想将函数映射到单个、明确指定的输入轴,请使用jax.vmap
如果您希望根据应用于输入的 numpy 广播规则将函数的输入映射到零个或多个轴上,请使用 jax.numpy.vectorize
。
在变换相同的情况下(例如在映射一维输入的函数时)倾向于使用vmap
,因为它更直接地执行您想做的事情。
【讨论】:
谢谢@jakevdp 你的意思是对于 nD 输入 (n>1) 在某些情况下矢量化更好吗?还是应该使用 vmap? 不是更好,而是不同。如果您希望根据 numpy 广播规则广播您的输入,请使用矢量化。如果您希望您的函数映射到输入的单个特定轴,请使用 vmap。 我编辑了答案以添加更多上下文 - 希望对您有所帮助! 喂!非常感谢@jakevdp,您丰富的答案非常有启发性。以上是关于Jax 矢量化:vmap 和/或 numpy.vectorize?的主要内容,如果未能解决你的问题,请参考以下文章