使用 Python JAX/Autograd 的向量值函数的雅可比行列式

Posted

技术标签:

【中文标题】使用 Python JAX/Autograd 的向量值函数的雅可比行列式【英文标题】:Jacobian determinant of vector-valued function with Python JAX/Autograd 【发布时间】:2020-05-01 08:57:40 【问题描述】:

我有一个将向量映射到向量的函数

我想计算它的Jacobian determinant

,

雅可比定义为

因为我可以使用numpy.linalg.det 来计算行列式,所以我只需要雅可比矩阵。我知道numdifftools.Jacobian,但这使用数值微分,我在自动微分之后。输入Autograd/JAX(我现在会坚持使用Autograd,它具有autograd.jacobian() 方法,但我很乐意使用JAX,只要我得到我想要的)。 如何正确使用 autograd.jacobian()-function 和向量值函数?

作为一个简单的例子,我们来看看函数

![f(x)=(x_0^2, x_1^2)](https://chart.googleapis.com/chart?cht=tx&chl=f(x%29%20%3D%20(x_0%5E2%2C%20x_1%5E2%29)

有雅可比行列

![J_f = diag(2 x_0, 2 x_1)](https://chart.googleapis.com/chart?cht=tx&chl=J_f%20%3D%20%5Coperatorname%7Bdiag%7D(2x_0%2C%202x_1%29)

导致雅可比行列式

>>> import autograd.numpy as np
>>> import autograd as ag
>>> x = np.array([[3],[11]])
>>> result = 4*x[0]*x[1]
array([132])
>>> jac = ag.jacobian(f)(x)
array([[[[ 6],
         [ 0]]],


       [[[ 0],
         [22]]]])
>>> jac.shape
(2, 1, 2, 1)
>>> np.linalg.det(jac)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/python3.8/site-packages/autograd/tracer.py", line 48, in f_wrapped
    return f_raw(*args, **kwargs)
  File "<__array_function__ internals>", line 5, in det
  File "/usr/lib/python3.8/site-packages/numpy/linalg/linalg.py", line 2113, in det
    _assert_stacked_square(a)
  File "/usr/lib/python3.8/site-packages/numpy/linalg/linalg.py", line 213, in _assert_stacked_square
    raise LinAlgError('Last 2 dimensions of the array must be square')
numpy.linalg.LinAlgError: Last 2 dimensions of the array must be square

第一种方法给了我正确的值,但形状错误。为什么.jacobian() 会返回这样一个嵌套数组?如果我正确地重塑它,我会得到正确的结果:

>>> jac = ag.jacobian(f)(x).reshape(-1,2,2)
array([[[ 6,  0],
        [ 0, 22]]])
>>> np.linalg.det(jac)
array([132.])

但是现在让我们看看这如何与数组广播一起工作,当我尝试评估 x 的多个值的雅​​可比行列式时

>>> x = np.array([[3,5,7],[11,13,17]])
array([[ 3,  5,  7],
       [11, 13, 17]])
>>> result = 4*x[0]*x[1]
array([132, 260, 476])
>>> jac = ag.jacobian(f)(x)
array([[[[ 6,  0,  0],
         [ 0,  0,  0]],

        [[ 0, 10,  0],
         [ 0,  0,  0]],

        [[ 0,  0, 14],
         [ 0,  0,  0]]],


       [[[ 0,  0,  0],
         [22,  0,  0]],

        [[ 0,  0,  0],
         [ 0, 26,  0]],

        [[ 0,  0,  0],
         [ 0,  0, 34]]]])
>>> jac = ag.jacobian(f)(x).reshape(-1,2,2)
>>> jac
array([[[ 6,  0],
        [ 0,  0]],

       [[ 0,  0],
        [ 0, 10]],

       [[ 0,  0],
        [ 0,  0]],

       [[ 0,  0],
        [14,  0]],

       [[ 0,  0],
        [ 0,  0]],

       [[ 0, 22],
        [ 0,  0]],

       [[ 0,  0],
        [ 0,  0]],

       [[26,  0],
        [ 0,  0]],

       [[ 0,  0],
        [ 0, 34]]])
>>> jac.shape
(9,2,2)

显然这两个形状都是错误的,正确的(如 我想要的雅可比矩阵)woule be

[[[ 6,  0],
  [ 0, 22]],
 [[10,  0],
  [ 0, 26]],
 [[14,  0],
  [ 0, 34]]]

shape=(6,2,2)

我需要如何使用autograd.jacobian(或jax.jacfwd/jax.jacrev)才能使其正确处理多个向量输入?


注意:使用显式循环并手动处理每个点,我得到了正确的结果。但是有没有办法做到这一点?

>>> dets = []
>>> for v in zip(*x):
>>>    v = np.array(v)
>>>    jac = ag.jacobian(f)(v)
>>>    print(jac, jac.shape, '\n')
>>>    det = np.linalg.det(jac)
>>>    dets.append(det)
 [[ 6.  0.]
 [ 0. 22.]] (2, 2)

 [[10.  0.]
 [ 0. 26.]] (2, 2)

 [[14.  0.]
 [ 0. 34.]] (2, 2)

>>> dets
 [131.99999999999997, 260.00000000000017, 475.9999999999998]

【问题讨论】:

【参考方案1】:

“如何正确使用这个 autograd.jacobian() 函数和向量值函数?”

你已经写了

x = np.array([[3],[11]])

这有两个问题。首先是这是一个向量的向量,而 autograd 是为向量到向量函数而设计的。第二个是 autograd 需要浮点数,而不是整数。如果你试图区分整数,你会得到一个错误。您没有看到向量向量的错误,因为 autograd 会自动将您的整数列表转换为浮点数列表。

TypeError: Can't differentiate w.r.t. type <class 'int'>

下面的代码应该给你决定因素。

import autograd.numpy as np
import autograd as ag

def f(x):
    return np.array([x[0]**2,x[1]**2])

x = np.array([3.,11.])
jac = ag.jacobian(f)(x)
result = np.linalg.det(jac)
print(result)

“我需要如何使用 autograd.jacobian(或 jax.jacfwd/jax.jacrev)才能正确处理多个向量输入?”

有一种方法可以做到这一点,它被称为 jax.vmap。请参阅 JAX 文档。 (https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)

在这种情况下,我可以使用以下代码计算雅可比行列式向量。请注意,我可以用与以前完全相同的方式定义函数 f,vmap 在幕后为我们完成工作。

import jax.numpy as np
import jax

def f(x):
    return np.array([x[0]**2,x[1]**2])

x = np.array([[3.,11.],[5.,13.],[7.,17.]])

jac = jax.jacobian(f)
vmap_jac = jax.vmap(jac)
result = np.linalg.det(vmap_jac(x))
print(result)

【讨论】:

以上是关于使用 Python JAX/Autograd 的向量值函数的雅可比行列式的主要内容,如果未能解决你的问题,请参考以下文章

matlab类中的向量化

使用 iOS 4.0 库时的向后兼容性

如何处理 UIPageViewController 中的向后滚动?

GCC:两个相似循环之间的向量化差异

sklearn:文本分类交叉验证中的向量化

我们是不是需要 C++ 中的向量化或 for 循环已经足够快?