Jax、jit 和动态形状:Tensorflow 的回归?
Posted
技术标签:
【中文标题】Jax、jit 和动态形状:Tensorflow 的回归?【英文标题】:Jax, jit and dynamic shapes: a regression from Tensorflow? 【发布时间】:2021-06-17 01:41:48 【问题描述】:documentation for JAX 说,
并非所有 JAX 代码都可以 JIT 编译,因为它要求数组形状是静态的并且在编译时已知。
现在我有点惊讶,因为 tensorflow 有像 tf.boolean_mask
这样的操作,它可以做 JAX 在编译时似乎无法做的事情。
-
为什么 TensorFlow 会出现这样的回归?我假设底层 XLA 表示在两个框架之间共享,但我可能弄错了。我不记得 Tensorflow 曾经在动态形状方面遇到过问题,而像
tf.boolean_mask
这样的函数一直存在。
我们可以期待这个差距在未来缩小吗?如果不是,为什么在 JAX 的 jit 中无法实现 Tensorflow(以及其他)所支持的功能?
编辑
渐变通过tf.boolean_mask
(显然不在掩码值上,它们是离散的);此处使用值未知的 TF1 样式图为例,因此 TF 不能依赖它们:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
x1 = tf.placeholder(tf.float32, (3,))
x2 = tf.placeholder(tf.float32, (3,))
y = tf.boolean_mask(x1, x2 > 0)
print(y.shape) # prints "(?,)"
dydx1, dydx2 = tf.gradients(y, [x1, x2])
assert dydx1 is not None and dydx2 is None
【问题讨论】:
这个问题对于 *** 来说有点太主观了。在github.com/google/jax/discussions 上询问您可能会更幸运 嗨@jakevdp,我不认为这个问题是主观的,因为它与 JAX 和 TF 中动态形状的运算符的 jit 编译能力有关。我同意我的问题的标题没有反映这一点。 好吧,让我换个说法:你问的是关于 JAX 的设计和路线图的事情; *** 版主经常将此类问题作为题外话关闭,而能够回答此类问题的人在 JAX 的 github 讨论中比在这里更活跃。 哦,我明白你的意思了。很公平。 【参考方案1】:目前,您不能(这里讨论过)
这不是 JAX jit vs TensorFlow 的限制,而是 XLA 的限制,或者更确切地说是两者的编译方式。
JAX 仅使用 XLA 来编译函数。 XLA 需要知道静态形状。这是 XLA 中固有的设计选择。
TensorFlow 使用function
:这会创建一个图形,该图形可以具有静态未知的形状。这不如使用 XLA 高效,但仍然可以。但是,tf.function
提供了一个选项jit_compile
,它将使用 XLA 编译函数内部的图形。虽然这通常会提供不错的加速(免费),但它也有一些限制:形状需要静态已知(惊喜、惊喜……)
这总体上不是太令人惊讶的行为:计算机中的计算通常更快(考虑到一个体面的优化器)以前知道的越多参数越多(内存布局,...)可以优化调度。知道的越少,代码越慢(在这端是普通的 Python)。
【讨论】:
【参考方案2】:我认为 JAX 并没有比 TensorFlow 更无能为力。在 JAX 中没有什么禁止你这样做的:
new_array = my_array[mask]
但是,mask
应该是索引(整数)而不是布尔值。这样,JAX 就知道new_array
的形状(与mask
相同)。从这个意义上说,我很确定 tf.boolean_mask
是不可微分的,即如果您尝试在某个时候计算它的梯度,它会引发错误。
更一般地说,如果您需要屏蔽数组,无论您使用什么库,都有两种方法:
-
如果您事先知道需要选择哪些索引并且需要提供这些索引以便库可以在编译之前计算形状;
如果您无法定义这些索引,无论出于何种原因,您都需要设计代码以避免防止填充影响您的结果。
每种情况的示例
假设您正在 JAX 中编写一个简单的嵌入层。 input
是一组对应多个句子的标记索引。为了获得与这些索引对应的词嵌入,我将简单地写word_embeddings = embeddings[input]
。由于我事先不知道句子的长度,我需要预先将所有标记序列填充到相同的长度,这样input
的形状为(number_of_sentences, sentence_max_length)
。现在,每当这个形状发生变化时,JAX 都会编译屏蔽操作。为了尽量减少编译次数,您可以提供相同数量的句子(也称为批量大小),您可以将sentence_max_length
设置为整个语料库中的最大句子长度。这样,训练期间将只有一个编译。当然,您需要在word_embeddings
中保留一行与pad 索引对应的行。但是,掩蔽仍然有效。
在模型的后面,假设您想将每个句子的每个单词表示为句子中所有其他单词的加权平均值(如自我注意机制)。权重是为整个批次并行计算的,并存储在维度为(number_of_sentences, sentence_max_length, sentence_max_length)
的矩阵A
中。加权平均值使用公式A @ word_embeddings
计算。现在,您需要确保填充标记不会影响前面的公式。为此,您可以将 A 中与焊盘索引相对应的条目清零,以消除它们对平均的影响。如果 pad token 索引为 0,你会这样做:
mask = jnp.array(input > 0, dtype=jnp.float32)
A = A * mask[:, jnp.newaxis, :]
weighted_mean = A @ word_embeddings
所以这里我们使用了一个布尔掩码,但是掩码在某种程度上是可微的,因为我们将掩码与另一个矩阵相乘,而不是使用它作为索引。请注意,我们应该以相同的方式删除同样对应于填充标记的weighted_mean
行。
【讨论】:
感谢您的回答。也许我不明白你的评论,但是渐变 do 会通过tf.boolean_mask
(显然不是通过掩码)。我编辑了我的答案以提供一个小的说明。
jax.numpy.where 或 np.asarray(condition).nonzero() 可能是使用 JAX 执行此操作的最接近的操作,但需要形状来 jit 它们。毕竟这不是梯度的问题,而是返回形状良好的数组的问题。如果在矩阵 X = [[1,2], [3,4]] 上使用掩码 = [[True, True], [False, True]] 会发生什么?
TF 对此的回答(确实是唯一可能的回答)是展平数组和掩码的公共维度。您的示例将在构建期间具有形状 (?,)
,在运行时具有 (3,)
。
它的能力不如 tf.function without xla 因为它需要具有静态已知形状。当您在示例中编写“掩码”时,我认为您的意思是“索引”,掩码 is 带有布尔值。而且,正如所写的,“这样,JAX 就知道形状”:TF 不需要与函数一起使用。顺便说一句,这与渐变无关!这是关于编译的。所以梯度的论点(我认为也不是真的)在这里并不适用。以上是关于Jax、jit 和动态形状:Tensorflow 的回归?的主要内容,如果未能解决你的问题,请参考以下文章