结合 scipy.root 和 Jax Jacobian

Posted

技术标签:

【中文标题】结合 scipy.root 和 Jax Jacobian【英文标题】:Combine scipy.root and Jax Jacobian 【发布时间】:2022-01-21 08:33:46 【问题描述】:

我在使用来自 scipy.root 的 JAX 雅可比行列时遇到问题。在下面的示例中,root 在没有雅可比的情况下工作,而在雅可比的情况下失败。关于我需要重写什么以使下面的代码与雅可比行列一起使用的任何想法?

from jax import jacfwd
from scipy.optimize import root
import numpy as np

def objectFunction(valuesEndo, varNamesEndo, valuesExo, varNamesExo, equations): 
  for i in range(len(varNamesExo)):
      exec("%s = %.10f" %(varNamesExo[i], valuesExo[i]))

  for i in range(len(varNamesEndo)):
    exec("%s = %.10f" %(varNamesEndo[i], valuesEndo[i]))
    
  equationVector = np.zeros(len(equations))
  for i in range(len(equations)):
      exec('equationVector[%d] = eval(equations[%d])' %(i, i))    
      
  return equationVector

varNamesEndo = ['x', 'y']
valuesEndoInitialGuess = [1., 1.]

varNamesExo = ['a', 'b']
valuesExo = [1., 1.]

equations = ['a*x+b*y**2-4',
            'np.exp(x) + x*y - 3']

equations = ['a*x**2 + b*y**2',
            'a*x**2 - b*y**2']

# Without Jacobian
sol1 =  root(fun=objectFunction,
            x0=valuesEndoInitialGuess, 
            args=(varNamesEndo, valuesExo, varNamesExo, equations))
#----> Works

# With Jacobian
jac  = jacfwd(objectFunction)
sol2 =  root(fun=objectFunction,
            x0=valuesEndoInitialGuess, 
            args=(varNamesEndo, valuesExo, varNamesExo, equations),
            jac=jac)
#----> Not woring

至少这条线似乎有问题

for i in range(len(varNamesEndo)):
        exec("%s = %.10f" %(varNamesEndo[i], valuesEndo[i]))

【问题讨论】:

【参考方案1】:

有两个问题:

    为了执行自动区分,JAX 依赖于用跟踪器替换值。这意味着您打印和评估值的字符串表示的方法将不起作用。 此外,您正在尝试将跟踪值分配给标准 numpy 数组。您应该改用 JAX 数组,因为它知道如何处理跟踪值。

考虑到这一点,你可以用这种方式重写你的函数,它应该可以工作,只要你的方程只使用 Python 算术运算和 jax 函数(而不是像 np.exp 这样的东西):

import jax.numpy as jnp

def objectFunction(valuesEndo, varNamesEndo, valuesExo, varNamesExo, equations): 
  for i in range(len(varNamesExo)):
      exec("%s = valuesExo[%d]" %(varNamesExo[i], i))

  for i in range(len(varNamesEndo)):
    exec("%s = valuesEndo[%d]" %(varNamesEndo[i], i))
    
  equationVector = jnp.zeros(len(equations))
  for i in range(len(equations)):
      equationVector = equationVector.at[i].set(eval(equations[i]))
      
  return equationVector

旁注:这种基于使用exec 设置变量名的方法非常脆弱且容易出错;我会建议一种基于建立显式命名空间来评估你的方程的方法。例如这样的:

def objectFunction(valuesEndo, varNamesEndo, valuesExo, varNamesExo, equations):
  namespace = 
      **dict(zip(varNamesEndo, valuesEndo)),
      **dict(zip(varNamesExo, valuesExo))
  
  return jnp.array([eval(eqn, namespace) for eqn in equations])

【讨论】:

以上是关于结合 scipy.root 和 Jax Jacobian的主要内容,如果未能解决你的问题,请参考以下文章

JAX的深度学习和科学计算

谷歌JAX快速入门笔记详解和案例

谷歌JAX快速入门笔记详解和案例

谷歌JAX快速入门笔记详解和案例

谷歌JAX快速入门笔记详解和案例

要替代TensorFlow?谷歌开源机器学习库JAX