关于基于等级的计算的自动微分
Posted
技术标签:
【中文标题】关于基于等级的计算的自动微分【英文标题】:Automatic Differentiation with respect to rank-based computations 【发布时间】:2022-01-09 19:55:08 【问题描述】:我是自动微分编程的新手,所以这可能是一个幼稚的问题。下面是我要解决的问题的简化版本。
我有两个输入数组——一个大小为N
的向量A
和一个形状为(N, M)
的矩阵B
,以及一个大小为M
的参数向量theta
。我定义了一个新数组C(theta) = B * theta
来获得一个大小为N
的新向量。然后我获取落在C
上下四分位数的元素的索引,并使用它们创建一个新数组A_low(theta) = A[lower quartile indices of C]
和A_high(theta) = A[upper quartile indices of C]
。显然这两个确实依赖于theta
,但是是否可以区分A_low
和A_high
w.r.t theta
?
到目前为止,我的尝试似乎表明没有——我使用了 autograd、JAX 和 tensorflow 的 python 库,但它们都返回零梯度。 (到目前为止我尝试过的方法包括使用 argsort 或使用 tf.top_k
提取相关子数组。)
我正在寻求帮助的是证明导数未定义(或无法分析计算),或者如果它确实存在,请提供有关如何估计它的建议。我的最终目标是最小化某些功能f(A_low, A_high)
wrt theta
。
【问题讨论】:
【参考方案1】:这是我根据您的描述编写的 JAX 计算:
import numpy as np
import jax.numpy as jnp
import jax
N = 10
M = 20
rng = np.random.default_rng(0)
A = jnp.array(rng.random((N,)))
B = jnp.array(rng.random((N, M)))
theta = jnp.array(rng.random(M))
def f(A, B, theta, k=3):
C = B @ theta
_, i_upper = lax.top_k(C, k)
_, i_lower = lax.top_k(-C, k)
return A[i_lower], A[i_upper]
x, y = f(A, B, theta)
dx_dtheta, dy_dtheta = jax.jacobian(f, argnums=2)(A, B, theta)
导数全为零,我相信这是正确的,因为输出值的变化不依赖于theta
的值变化。
但是,您可能会问,这怎么可能?毕竟,theta
进入了计算,如果你为theta
输入不同的值,你会得到不同的输出。梯度怎么可能为零?
不过,您必须记住的是,差异化并不能衡量输入是否影响输出。它测量在输入变化很小的情况下输出的变化。
我们以一个稍微简单的函数为例:
import jax
import jax.numpy as jnp
A = jnp.array([1.0, 2.0, 3.0])
theta = jnp.array([5.0, 1.0, 3.0])
def f(A, theta):
return A[jnp.argmax(theta)]
x = f(A, theta)
dx_dtheta = jax.grad(f, argnums=1)(A, theta)
这里将f
与theta
微分的结果全为零,原因同上。为什么?如果您对theta
进行微小更改,通常不会影响theta
的排序顺序。因此,您从 A
中选择的条目不会因为 theta 的微小变化而改变,因此相对于 theta 的导数为零。
现在,您可能会争辩说,在某些情况下情况并非如此:例如,如果 theta 中的两个值非常接近,那么即使微乎其微地扰动其中一个值,它们各自的等级也可能会发生变化。这是真的,但是这个过程产生的梯度是不确定的(输出的变化相对于输入的变化并不平滑)。好消息是这种不连续性是一方面的:如果你在另一个方向上扰动,排名没有变化,梯度是明确定义的。为了避免未定义的梯度,大多数 autodiff 系统将隐式使用这种更安全的导数定义来进行基于秩的计算。
结果是当你对输入进行无限微扰时,输出的值不会改变,这是梯度为零的另一种说法。这并不是 autodiff 的失败——它是基于 autodiff 的微分定义的正确梯度。此外,如果您尝试在这些不连续处更改为导数的不同定义,您可能希望得到的最好结果将是未定义的输出,因此导致零的定义可以说更有用和更正确。
【讨论】:
谢谢。事后看来,我没有考虑到theta
的微小变化,这似乎是微不足道的。现在这是有道理的,尤其是当我看你提到的更简单的例子时。如果我想找到“最佳”θ,我想进行随机搜索是我最好的选择——尽管这可能是它自己的 SO 问题。以上是关于关于基于等级的计算的自动微分的主要内容,如果未能解决你的问题,请参考以下文章