为啥我的 JAX + Haiku 代码不在 GPU 上运行?

Posted

技术标签:

【中文标题】为啥我的 JAX + Haiku 代码不在 GPU 上运行?【英文标题】:Why does my JAX + Haiku code don't run on GPU?为什么我的 JAX + Haiku 代码不在 GPU 上运行? 【发布时间】:2021-12-23 14:52:47 【问题描述】:

下午好, 我刚开始学习 JAX 和 Haiku,但我无法在 GPU 上运行我的代码。我在激活 GPU 的情况下在 Google Colab 和 Kaggle 笔记本中运行我的代码,但它比停用 GPU 需要更多时间。 此外,当我查看 GPU 指标时,我发现我只使用了 1% 的计算能力,但使用了 90% 的 GPU 内存。 这是我的代码(MNIST 的 MLP):

def mlp(images):
  model = hk.Sequential([hk.Linear(128),
                      jax.nn.relu,
                      hk.Linear(64),
                      jax.nn.relu,
                      hk.Linear(10),
                      jax.nn.log_softmax])
  return model(images)

def loss(params, model, images, labels):
  logits = model.apply(params = params, images = images)
  labels = jax.nn.one_hot(labels, num_classes = 10)
  cross_entropy_loss = -jnp.sum(labels*logits)/len(labels)
  return cross_entropy_loss
# Initializing the MLP model
mlp = hk.without_apply_rng(hk.transform(mlp))
params = mlp.init(rng = jax.random.PRNGKey(0),
                  images = next(iter(train_loader))[0])

# Initializing the optimizer
opt = optax.adam(1e-4)
opt_state = opt.init(params = params)
@jax.jit
def update(params, opt_state, images, labels):
  grads = jax.grad(loss)(params,mlp,images,labels)
  updates, opt_state = opt.update(grads, opt_state)
  return optax.apply_updates(params, updates), opt_state

def train(params, opt_state, epochs):
  for epoch in range(epochs):
    for batch_idx, (images, labels) in enumerate(train_loader):
      if batch_idx == 0:
        print(f"Epoch epoch : loss = loss(params,mlp,images,labels)")
      params, opt_state = update(params, opt_state, images,labels)

%time train(params, opt_state, epochs = 10)

如果你知道我做错了什么,你会帮我很多。 谢谢。

【问题讨论】:

【参考方案1】:

这个问题很难回答,因为不清楚epochstrain_loader 包含什么。但一般的回应:

默认情况下,JAX 将始终在启动时预分配 90% 的 GPU 内存(请参阅 GPU Memory Allocation),因此这并不表示您的计算消耗了多少内存 for 循环等 Python 控制流将在您的 CPU 上执行,将内部计算一一分派到 GPU。除非程序遇到阻塞调用,例如打印计算结果,否则此调度将尽可能异步(请参阅Asynchronous Dispatch)。

鉴于这些事实,我怀疑您的代码运行缓慢并且没有使 GPU 饱和的原因是因为每个 update 操作是一个非常小的计算,因此每个循环内的调度开销占主导地位。调度开销通常是由于设备传输造成的(即,如果epochstrain_loader 的内容尚未存在于 GPU 上)。由于异步调度,如果您避免阻塞调用(例如在循环中打印损失函数),这种调度开销的累积效应可能不会成为问题。更好的解决方案可能是将循环推送到 XLA 中(如果循环数量很少,则通过 JIT 编译整个训练过程,或者如果循环数量很大,则使用 lax control flow),但这取决于epochstrain_loader 以及数据是保存在设备上还是需要传输。

【讨论】:

【参考方案2】:

几天前我遇到了这个问题,我的进程在 cpu 上运行,但使用了这么多 GPU 内存。 原因可能是你安装了错误的 jax 和 jaxlib 版本,它只是 cpu,你可以通过安装 gpu 版本来解决这个问题,如下所示:

   pip install --upgrade jax==0.2.3 jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

注意!!!你最好检查你的服务器/计算机的 cuda 驱动程序版本 此外,您可以浏览 https://github.com/google/jax

你会知道更多细节

【讨论】:

以上是关于为啥我的 JAX + Haiku 代码不在 GPU 上运行?的主要内容,如果未能解决你的问题,请参考以下文章

火炬代码不在 GPU 上运行

Haiku Generator - 如何运行这个脚本? [复制]

为啥我的 colab 笔记本不使用 GPU?

为啥在这段代码中 CPU 运行速度比 GPU 快?

为啥我的 Kmeans CuPy 代码中有“OutOfMemoryError”?

与 pyTorch 相比,Jax/Flax(非常)慢的 RNN 前向传递?