使用 JAX 和 SciPy 对不正确积分进行微分

Posted

技术标签:

【中文标题】使用 JAX 和 SciPy 对不正确积分进行微分【英文标题】:Differentiation of an improper integral using JAX and SciPy 【发布时间】:2022-01-21 06:51:20 【问题描述】:

我提供了一个简单的代码示例,说明使用 SciPyquad() 方法通过不正确的积分函数自动区分使用 JAX 失败的尝试。我考虑的功能是

渐变由

以下代码能够计算函数,但是当我尝试计算梯度时,JAX 会抛出 ConcretizationTypeError 错误。 这里有什么问题以及如何解决?

import jax
from scipy.integrate import quad

## Function
def F(c1, c2):
    val, err = quad(lambda x: c1/(1.0 + x**2), a=c2, b=jax.numpy.inf)
    return val

## Gradient
grad_F = jax.grad(F)

## Parameters
c1 = -1.0
c2 = 0.0

## Evaluates function
F(c1, c2)
# -1.5707963267948966   (which is -pi/2 btw)

## Evaluates gradient
grad_F(c1, c2)

投掷:

---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
/tmp/ipykernel_446012/1229440296.py in <module>
----> 1 grad_F(c1, c2)

    [... skipping hidden 9 frame]

/tmp/ipykernel_446012/2999885932.py in F(c1, c2)
      5 def F(c1, c2):
      6     #val, err = jax.numpy.array(quad(lambda y: b/(1.0+y**2), a=a, b=jax.numpy.inf), float)
----> 7     val, err = quad(lambda x: c1/(1.0 + x**2), a=c2, b=jax.numpy.inf)
      8     return val
      9 

~/anaconda3/lib/python3.8/site-packages/scipy/integrate/quadpack.py in quad(func, a, b, args, full_output, epsabs, epsrel, limit, points, weight, wvar, wopts, maxp1, limlst)
    349 
    350     if weight is None:
--> 351         retval = _quad(func, a, b, args, full_output, epsabs, epsrel, limit,
    352                        points)
    353     else:

~/anaconda3/lib/python3.8/site-packages/scipy/integrate/quadpack.py in _quad(func, a, b, args, full_output, epsabs, epsrel, limit, points)
    463             return _quadpack._qagse(func,a,b,args,full_output,epsabs,epsrel,limit)
    464         else:
--> 465             return _quadpack._qagie(func,bound,infbounds,args,full_output,epsabs,epsrel,limit)
    466     else:
    467         if infbounds != 0:

    [... skipping hidden 1 frame]

~/anaconda3/lib/python3.8/site-packages/jax/core.py in error(self, arg)
    998                       f"or `jnp.array(x, fun.__name__)` instead.")
    999   def error(self, arg):
-> 1000     raise ConcretizationTypeError(arg, fname_context)
   1001   return error
   1002 

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ConcreteArray(-0.5, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(-0.5, dtype=float32, weak_type=True)
  tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[], weak_type=True), *)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fbf4f643b90>, invars=(Traced<ConcreteArray(2.0, dtype=float32):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True):JaxprTrace(level=1/0)>), outvars=[<weakref at 0x7fbf4c402c20; to 'JaxprTracer' at 0x7fbf4eca5090>], primitive=xla_call, params='device': None, 'backend': None, 'name': 'jvp(true_divide)', 'donated_invars': (False, False), 'inline': True, 'call_jaxpr':  lambda ; a:f32[] b:f32[]. let c:f32[] = div b a in (c,) , source_info=<jaxlib.xla_extension.Traceback object at 0x7fbf4eca1bb0>)
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

【问题讨论】:

【参考方案1】:

问题在于 JAX 的 grad 转换只能对完全由 JAX 操作组成的函数进行操作,而 scipy.integrate.quad 不是 JAX 操作。如果你想做这种计算,你必须找到quad 的 JAX 实现。

【讨论】:

我有点期待这个,但是解决这个问题的最简单方法是什么?你能提供一个解决方案吗?这是我的第二部分问题。是否有有效的替代方案来集成 jax?

以上是关于使用 JAX 和 SciPy 对不正确积分进行微分的主要内容,如果未能解决你的问题,请参考以下文章

Scipy---5.数值积分

有没有办法用 scipy.fft 在傅里叶空间中进行数值积分?

SciPy 科学计算基础

结合 scipy.root 和 Jax Jacobian

机器学习基础 | Scipy 简易入门

Python------SciPy模块