使用 JAX 和 SciPy 对不正确积分进行微分
Posted
技术标签:
【中文标题】使用 JAX 和 SciPy 对不正确积分进行微分【英文标题】:Differentiation of an improper integral using JAX and SciPy 【发布时间】:2022-01-21 06:51:20 【问题描述】:我提供了一个简单的代码示例,说明使用 SciPy
的 quad()
方法通过不正确的积分函数自动区分使用 JAX
失败的尝试。我考虑的功能是
渐变由
以下代码能够计算函数,但是当我尝试计算梯度时,JAX
会抛出 ConcretizationTypeError
错误。 这里有什么问题以及如何解决?
import jax
from scipy.integrate import quad
## Function
def F(c1, c2):
val, err = quad(lambda x: c1/(1.0 + x**2), a=c2, b=jax.numpy.inf)
return val
## Gradient
grad_F = jax.grad(F)
## Parameters
c1 = -1.0
c2 = 0.0
## Evaluates function
F(c1, c2)
# -1.5707963267948966 (which is -pi/2 btw)
## Evaluates gradient
grad_F(c1, c2)
投掷:
---------------------------------------------------------------------------
ConcretizationTypeError Traceback (most recent call last)
/tmp/ipykernel_446012/1229440296.py in <module>
----> 1 grad_F(c1, c2)
[... skipping hidden 9 frame]
/tmp/ipykernel_446012/2999885932.py in F(c1, c2)
5 def F(c1, c2):
6 #val, err = jax.numpy.array(quad(lambda y: b/(1.0+y**2), a=a, b=jax.numpy.inf), float)
----> 7 val, err = quad(lambda x: c1/(1.0 + x**2), a=c2, b=jax.numpy.inf)
8 return val
9
~/anaconda3/lib/python3.8/site-packages/scipy/integrate/quadpack.py in quad(func, a, b, args, full_output, epsabs, epsrel, limit, points, weight, wvar, wopts, maxp1, limlst)
349
350 if weight is None:
--> 351 retval = _quad(func, a, b, args, full_output, epsabs, epsrel, limit,
352 points)
353 else:
~/anaconda3/lib/python3.8/site-packages/scipy/integrate/quadpack.py in _quad(func, a, b, args, full_output, epsabs, epsrel, limit, points)
463 return _quadpack._qagse(func,a,b,args,full_output,epsabs,epsrel,limit)
464 else:
--> 465 return _quadpack._qagie(func,bound,infbounds,args,full_output,epsabs,epsrel,limit)
466 else:
467 if infbounds != 0:
[... skipping hidden 1 frame]
~/anaconda3/lib/python3.8/site-packages/jax/core.py in error(self, arg)
998 f"or `jnp.array(x, fun.__name__)` instead.")
999 def error(self, arg):
-> 1000 raise ConcretizationTypeError(arg, fname_context)
1001 return error
1002
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ConcreteArray(-0.5, dtype=float32)>with<JVPTrace(level=2/0)> with
primal = DeviceArray(-0.5, dtype=float32, weak_type=True)
tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[], weak_type=True), *)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fbf4f643b90>, invars=(Traced<ConcreteArray(2.0, dtype=float32):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True):JaxprTrace(level=1/0)>), outvars=[<weakref at 0x7fbf4c402c20; to 'JaxprTracer' at 0x7fbf4eca5090>], primitive=xla_call, params='device': None, 'backend': None, 'name': 'jvp(true_divide)', 'donated_invars': (False, False), 'inline': True, 'call_jaxpr': lambda ; a:f32[] b:f32[]. let c:f32[] = div b a in (c,) , source_info=<jaxlib.xla_extension.Traceback object at 0x7fbf4eca1bb0>)
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
【问题讨论】:
【参考方案1】:问题在于 JAX 的 grad
转换只能对完全由 JAX 操作组成的函数进行操作,而 scipy.integrate.quad
不是 JAX 操作。如果你想做这种计算,你必须找到quad
的 JAX 实现。
【讨论】:
我有点期待这个,但是解决这个问题的最简单方法是什么?你能提供一个解决方案吗?这是我的第二部分问题。是否有有效的替代方案来集成 jax?以上是关于使用 JAX 和 SciPy 对不正确积分进行微分的主要内容,如果未能解决你的问题,请参考以下文章