为啥我的 JAX + Haiku 代码不在 GPU 上运行?
Posted
技术标签:
【中文标题】为啥我的 JAX + Haiku 代码不在 GPU 上运行?【英文标题】:Why does my JAX + Haiku code don't run on GPU?为什么我的 JAX + Haiku 代码不在 GPU 上运行? 【发布时间】:2021-12-23 14:52:47 【问题描述】:下午好, 我刚开始学习 JAX 和 Haiku,但我无法在 GPU 上运行我的代码。我在激活 GPU 的情况下在 Google Colab 和 Kaggle 笔记本中运行我的代码,但它比停用 GPU 需要更多时间。 此外,当我查看 GPU 指标时,我发现我只使用了 1% 的计算能力,但使用了 90% 的 GPU 内存。 这是我的代码(MNIST 的 MLP):
def mlp(images):
model = hk.Sequential([hk.Linear(128),
jax.nn.relu,
hk.Linear(64),
jax.nn.relu,
hk.Linear(10),
jax.nn.log_softmax])
return model(images)
def loss(params, model, images, labels):
logits = model.apply(params = params, images = images)
labels = jax.nn.one_hot(labels, num_classes = 10)
cross_entropy_loss = -jnp.sum(labels*logits)/len(labels)
return cross_entropy_loss
# Initializing the MLP model
mlp = hk.without_apply_rng(hk.transform(mlp))
params = mlp.init(rng = jax.random.PRNGKey(0),
images = next(iter(train_loader))[0])
# Initializing the optimizer
opt = optax.adam(1e-4)
opt_state = opt.init(params = params)
@jax.jit
def update(params, opt_state, images, labels):
grads = jax.grad(loss)(params,mlp,images,labels)
updates, opt_state = opt.update(grads, opt_state)
return optax.apply_updates(params, updates), opt_state
def train(params, opt_state, epochs):
for epoch in range(epochs):
for batch_idx, (images, labels) in enumerate(train_loader):
if batch_idx == 0:
print(f"Epoch epoch : loss = loss(params,mlp,images,labels)")
params, opt_state = update(params, opt_state, images,labels)
%time train(params, opt_state, epochs = 10)
如果你知道我做错了什么,你会帮我很多。 谢谢。
【问题讨论】:
【参考方案1】:这个问题很难回答,因为不清楚epochs
或train_loader
包含什么。但一般的回应:
for
循环等 Python 控制流将在您的 CPU 上执行,将内部计算一一分派到 GPU。除非程序遇到阻塞调用,例如打印计算结果,否则此调度将尽可能异步(请参阅Asynchronous Dispatch)。
鉴于这些事实,我怀疑您的代码运行缓慢并且没有使 GPU 饱和的原因是因为每个 update
操作是一个非常小的计算,因此每个循环内的调度开销占主导地位。调度开销通常是由于设备传输造成的(即,如果epochs
或train_loader
的内容尚未存在于 GPU 上)。由于异步调度,如果您避免阻塞调用(例如在循环中打印损失函数),这种调度开销的累积效应可能不会成为问题。更好的解决方案可能是将循环推送到 XLA 中(如果循环数量很少,则通过 JIT 编译整个训练过程,或者如果循环数量很大,则使用 lax control flow),但这取决于epochs
和 train_loader
以及数据是保存在设备上还是需要传输。
【讨论】:
【参考方案2】:几天前我遇到了这个问题,我的进程在 cpu 上运行,但使用了这么多 GPU 内存。 原因可能是你安装了错误的 jax 和 jaxlib 版本,它只是 cpu,你可以通过安装 gpu 版本来解决这个问题,如下所示:
pip install --upgrade jax==0.2.3 jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
注意!!!你最好检查你的服务器/计算机的 cuda 驱动程序版本 此外,您可以浏览 https://github.com/google/jax
你会知道更多细节
【讨论】:
以上是关于为啥我的 JAX + Haiku 代码不在 GPU 上运行?的主要内容,如果未能解决你的问题,请参考以下文章
Haiku Generator - 如何运行这个脚本? [复制]