关于基于等级的计算的自动微分

Posted

技术标签:

【中文标题】关于基于等级的计算的自动微分【英文标题】:Automatic Differentiation with respect to rank-based computations 【发布时间】:2022-01-09 19:55:08 【问题描述】:

我是自动微分编程的新手,所以这可能是一个幼稚的问题。下面是我要解决的问题的简化版本。

我有两个输入数组——一个大小为N 的向量A 和一个形状为(N, M) 的矩阵B,以及一个大小为M 的参数向量theta。我定义了一个新数组C(theta) = B * theta 来获得一个大小为N 的新向量。然后我获取落在C 上下四分位数的元素的索引,并使用它们创建一个新数组A_low(theta) = A[lower quartile indices of C]A_high(theta) = A[upper quartile indices of C]。显然这两个确实依赖于theta,但是是否可以区分A_lowA_high w.r.t theta

到目前为止,我的尝试似乎表明没有——我使用了 autograd、JAX 和 tensorflow 的 python 库,但它们都返回零梯度。 (到目前为止我尝试过的方法包括使用 argsort 或使用 tf.top_k 提取相关子数组。)

我正在寻求帮助的是证明导数未定义(或无法分析计算),或者如果它确实存在,请提供有关如何估计它的建议。我的最终目标是最小化某些功能f(A_low, A_high) wrt theta

【问题讨论】:

【参考方案1】:

这是我根据您的描述编写的 JAX 计算:

import numpy as np
import jax.numpy as jnp
import jax

N = 10
M = 20

rng = np.random.default_rng(0)
A = jnp.array(rng.random((N,)))
B = jnp.array(rng.random((N, M)))
theta = jnp.array(rng.random(M))

def f(A, B, theta, k=3):
  C = B @ theta
  _, i_upper = lax.top_k(C, k)
  _, i_lower = lax.top_k(-C, k)
  return A[i_lower], A[i_upper]

x, y = f(A, B, theta)
dx_dtheta, dy_dtheta = jax.jacobian(f, argnums=2)(A, B, theta)

导数全为零,我相信这是正确的,因为输出值的变化不依赖于theta的值变化。

但是,您可能会问,这怎么可能?毕竟,theta 进入了计算,如果你为theta 输入不同的值,你会得到不同的输出。梯度怎么可能为零?

不过,您必须记住的是,差异化并不能衡量输入是否影响输出。它测量在输入变化很小的情况下输出的变化

我们以一个稍微简单的函数为例:

import jax
import jax.numpy as jnp

A = jnp.array([1.0, 2.0, 3.0])
theta = jnp.array([5.0, 1.0, 3.0])

def f(A, theta):
  return A[jnp.argmax(theta)]

x = f(A, theta)
dx_dtheta = jax.grad(f, argnums=1)(A, theta)

这里将ftheta 微分的结果全为零,原因同上。为什么?如果您对theta 进行微小更改,通常不会影响theta 的排序顺序。因此,您从 A 中选择的条目不会因为 theta 的微小变化而改变,因此相对于 theta 的导数为零。

现在,您可能会争辩说,在某些情况下情况并非如此:例如,如果 theta 中的两个值非常接近,那么即使微乎其微地扰动其中一个值,它们各自的等级也可能会发生变化。这是真的,但是这个过程产生的梯度是不确定的(输出的变化相对于输入的变化并不平滑)。好消息是这种不连续性是一方面的:如果你在另一个方向上扰动,排名没有变化,梯度是明确定义的。为了避免未定义的梯度,大多数 autodiff 系统将隐式使用这种更安全的导数定义来进行基于秩的计算。

结果是当你对输入进行无限微扰时,输出的值不会改变,这是梯度为零的另一种说法。这并不是 autodiff 的失败——它是基于 autodiff 的微分定义的正确梯度。此外,如果您尝试在这些不连续处更改为导数的不同定义,您可能希望得到的最好结果将是未定义的输出,因此导致零的定义可以说更有用和更正确。

【讨论】:

谢谢。事后看来,我没有考虑到theta 的微小变化,这似乎是微不足道的。现在这是有道理的,尤其是当我看你提到的更简单的例子时。如果我想找到“最佳”θ,我想进行随机搜索是我最好的选择——尽管这可能是它自己的 SO 问题。

以上是关于关于基于等级的计算的自动微分的主要内容,如果未能解决你的问题,请参考以下文章

计算图与自动微分

如何计算微分

d2l自动微分练习

d2l自动微分练习

PyTorch自动微分基本原理

Python - Automatic Differentiation 自动微分