使用 JAX 进行梯度累积

Posted

技术标签:

【中文标题】使用 JAX 进行梯度累积【英文标题】:Gradient Accumulation with JAX 【发布时间】:2021-09-02 02:02:55 【问题描述】:

我做了一个简单的脚本来尝试用 JAX 进行梯度累积。这个想法是让大批量大小(例如 64 个)分成适合 GPU 内存的小块(例如 4 个)。对于每个块,将存储在 pytree 中的结果梯度添加到当前批次梯度中。仅当计算大批量的所有块时才完成更新。在这个特定示例中,我们只是尝试将随机 512 维向量拟合到具有线性层的随机布尔值。这是脚本:

import jax
import jax.numpy as jnp
from jax import jit, random
from jax.experimental import optimizers
from functools import partial
from jax.nn.initializers import normal, zeros
from typing import Callable
from dataclasses import dataclass

@dataclass
class Jax_model:
    init_fun: Callable
    apply_fun: Callable


def Dense(input_size: int, output_size: int, init_kernel=normal(), init_bias=zeros):

    def init_fun(key):
        key, sub_key1, sub_key2 = jax.random.split(key, 3)
        params = 
            'I': init_kernel(sub_key1, (input_size, output_size) ),
            'I_b': init_bias(sub_key2, (1,output_size) ),
        
        return params

    def apply_fun(params, inputs):
        I, I_b, = params['I'], params['I_b']
        logits = inputs @ I + I_b
        return logits

    return Jax_model(init_fun, apply_fun)


def divide_pytree(pytree, div):
    for pt in jax.tree_util.tree_leaves(pytree):
        pt = pt / div
    return pytree


def add_pytrees(pytree1, pytree2):
    for pt1, pt2 in zip( jax.tree_util.tree_leaves(pytree1), jax.tree_util.tree_leaves(pytree2) ):
        pt1 = pt1 + pt2
    return pytree1


rng_key = random.PRNGKey(42)
batch_size = 64
accumulation_size = 4
model_dim = 512
n_iter = 50

model = Dense(model_dim, 1)
rng_key, sub_key = random.split(rng_key)
init_params = model.init_fun(sub_key)
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(init_params)

@jit
def update(i, current_opt_state, current_batch):
    N = current_batch[0].shape[0]
    K = accumulation_size
    num_gradients = N//K
    accumulation_batch = (current_batch[ib][0:K] for ib in range(len(current_batch)))
    value, grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
    value = value / num_gradients
    grads = divide_pytree(grads, num_gradients)
    for k in range(K,N,K):
        accumulation_batch = (current_batch[ib][k:k+K] for ib in range(len(current_batch)))
        new_value, new_grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
        value = value + (new_value / num_gradients)
        grads = add_pytrees(grads, divide_pytree(new_grads, num_gradients))
    return opt_update(i, grads, current_opt_state), value

def loss_func(current_params, current_batch):
    inputs, labels = current_batch
    predictions = model.apply_fun(current_params, inputs)
    loss = jnp.square(labels-predictions).sum()
    return loss

for i in range(n_iter):
    rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)
    inputs = jax.random.uniform(sub_key1, (batch_size, model_dim))
    labels = jax.random.uniform(sub_key2, (batch_size, 1)) > 0.5
    batch = inputs, labels
    opt_state, batch_loss = update(i, opt_state, batch)
    print(i, batch_loss)

我对@9​​87654324@ 和add_pytrees 有疑问。它是否真的修改了当前的批次梯度或者我错过了什么?此外,您是否看到此代码有任何速度问题?特别是,我应该使用jax.lax.fori_loop 代替传统的python for 循环吗?

相关链接:

https://github.com/google/jax/issues/1488 https://github.com/google-research/long-range-arena/issues/4

【问题讨论】:

【参考方案1】:

关于 pytree 计算:正如所写,您的函数返回未修改的输入。更好的方法是使用jax.tree_util.tree_map;例如:

from jax.tree_util import tree_map

def divide_pytree(pytree, div):
  return tree_map(lambda pt: pt / div, pytree)

def add_pytrees(pytree1, pytree2):
  return tree_map(lambda pt1, pt2: pt1 + pt2, pytree1, pytree2)

关于性能:for 循环中的任何内容在 JIT 编译时都将被展平,每次循环迭代都会重复复制所有 XLA 指令。如果您有 5 次迭代,那并不是真正的问题。如果您有 5000 个,那将显着减慢编译时间(因为 XLA 需要分析和优化循环中指令的 5000 个显式副本)。

fori_loop 可以提供帮助,但不会产生最佳代码,尤其是在 CPU 和 GPU 上运行时。

最好在可能的情况下使用广播或 vmapped 操作来表达循环的逻辑,而无需显式循环。

【讨论】:

另外,你的函数add_pytrees应该使用tree_multimap而不是tree_map来处理2个pytrees! 我不明白你的第一个问题。至于你的第二个问题,tree_multimap 是 recently deprecated 因为tree_map 做同样的事情。 对,我有一个旧版本的 JAX,感谢您的澄清。不知道fori_loop会不会加快编译速度? fori_loop 会加快编译速度。在跟踪/编译期间,Python 循环是扁平的,这意味着它们可以生成非常长的中间表示。

以上是关于使用 JAX 进行梯度累积的主要内容,如果未能解决你的问题,请参考以下文章

如何在 TF 2.0 / 1.14.0-eager 和自定义训练循环(梯度磁带)中执行梯度累积?

查找函数的梯度:Sympy vs. Jax

TensorFlow:去除累积渐变中的nans

梯度下降优化方法 与 自动控制 的关系

hog

使用 numpy 和 jax 进行非传递子类化