与 pyTorch 相比,Jax/Flax(非常)慢的 RNN 前向传递?

Posted

技术标签:

【中文标题】与 pyTorch 相比,Jax/Flax(非常)慢的 RNN 前向传递?【英文标题】:Jax/Flax (very) slow RNN-forward-pass compared to pyTorch? 【发布时间】:2021-12-14 11:01:54 【问题描述】:

我最近在 Jax 中实现了一个两层 GRU 网络,但对其性能感到失望(无法使用)。

所以,我尝试了与 Pytorch 的速度比较。

最小的工作示例

这是我的最小工作示例,输出是在 Google Colab 上使用 GPU 运行时创建的。 notebook in colab

import flax.linen as jnn 
import jax
import torch
import torch.nn as tnn
import numpy as np 
import jax.numpy as jnp

def keyGen(seed):
    key1 = jax.random.PRNGKey(seed)
    while True:
        key1, key2 = jax.random.split(key1)
        yield key2
key = keyGen(1)

hidden_size=200
seq_length = 1000
in_features = 6
out_features = 4
batch_size = 8

class RNN_jax(jnn.Module):

    @jnn.compact
    def __call__(self, x, carry_gru1, carry_gru2):
        carry_gru1, x = jnn.GRUCell()(carry_gru1, x)
        carry_gru2, x = jnn.GRUCell()(carry_gru2, x)
        x = jnn.Dense(4)(x)
        x = x/jnp.linalg.norm(x)
        return x, carry_gru1, carry_gru2

class RNN_torch(tnn.Module):
    def __init__(self, batch_size, hidden_size, in_features, out_features):
        super().__init__()

        self.gru = tnn.GRU(
            input_size=in_features, 
            hidden_size=hidden_size,
            num_layers=2
            )
        
        self.dense = tnn.Linear(hidden_size, out_features)

        self.init_carry = torch.zeros((2, batch_size, hidden_size))

    def forward(self, X):
        X, final_carry = self.gru(X, self.init_carry)
        X = self.dense(X)
        return X/X.norm(dim=-1).unsqueeze(-1).repeat((1, 1, 4))

rnn_jax = RNN_jax()
rnn_torch = RNN_torch(batch_size, hidden_size, in_features, out_features)

Xj = jax.random.normal(next(key), (seq_length, batch_size, in_features))
Yj = jax.random.normal(next(key), (seq_length, batch_size, out_features))
Xt = torch.from_numpy(np.array(Xj))
Yt = torch.from_numpy(np.array(Yj))

initial_carry_gru1 = jnp.zeros((batch_size, hidden_size))
initial_carry_gru2 = jnp.zeros((batch_size, hidden_size))

params = rnn_jax.init(next(key), Xj[0], initial_carry_gru1, initial_carry_gru2)

def forward(params, X):
    
    carry_gru1, carry_gru2 = initial_carry_gru1, initial_carry_gru2

    Yhat = []
    for x in X: # x.shape = (batch_size, in_features)
        yhat, carry_gru1, carry_gru2 = rnn_jax.apply(params, x, carry_gru1, carry_gru2)
        Yhat.append(yhat) # y.shape = (batch_size, out_features)

    #return jnp.concatenate(Y, axis=0)

jitted_forward = jax.jit(forward)

结果
# uncompiled jax version
%time forward(params, Xj)

CPU times: user 7min 17s, sys: 8.18 s, total: 7min 25s Wall time: 7min 17s

# time for compiling
%time jitted_forward(params, Xj)

CPU times: user 8min 9s, sys: 4.46 s, total: 8min 13s Wall time: 8min 12s

# compiled jax version
%timeit jitted_forward(params, Xj)

The slowest run took 204.20 times longer than the fastest. This could mean that an intermediate result is being cached. 10000 loops, best of 5: 115 µs per loop

# torch version
%timeit lambda: rnn_torch(Xt)

10000000 loops, best of 5: 65.7 ns per loop

问题

为什么我的 Jax 实现如此缓慢?我做错了什么?

另外,为什么编译需要这么长时间?序列没那么长..

谢谢你:)

【问题讨论】:

你的例子有一些未定义的变量,即nnn_hidden 哦,对不起。我修好了。 【参考方案1】:

JAX 代码编译缓慢的原因是在 JIT 编译期间 JAX 展开循环。所以在 XLA 编译方面,你的函数其实很大:你调用rnn_jax.apply() 1000 次,编译时间往往是语句数量的大致二次方。

相比之下,您的 pytorch 函数不使用 Python 循环,因此在底层它依赖于运行速度更快的矢量化操作。

任何时候你在 Python 中对数据使用 for 循环,一个不错的选择是你的代码会很慢:无论你使用的是 JAX、torch、numpy、pandas 等,这都是真的。我建议在 JAX 中找到一种解决问题的方法,该方法依赖于矢量化操作,而不是依赖于慢速 Python 循环。

【讨论】:

这是有道理的,但是你将如何对 RNN 进行矢量化。序列维度取决于每个先前的计算。我不确定 PyTorch 是如何做到的,但它有一个内置的序列维度,因此我不需要循环。 在 JAX 中,我认为 fori_loop 可能是您最好的选择。但值得检查亚麻示例以了解其通常是如何完成的。 哇哦。我想我想通了。您给 rnn_jax.apply(X) 的 X 中的任何其他维度都会以某种方式被减少,就好像它是一个序列维度一样。所以就像在 pyTorch 中一样。我不知道应该如何从文档中知道这一点。如果你好奇的话,也许我今天下午重做速度运行并更新结果。

以上是关于与 pyTorch 相比,Jax/Flax(非常)慢的 RNN 前向传递?的主要内容,如果未能解决你的问题,请参考以下文章

干货丨新手必读,PyTorch与TensorFlow的全方位对比

PyTorch笔记 - SwinTransformer的原理与实现

PyTorch笔记 - SwinTransformer的原理与实现

PyTorch笔记 - SwinTransformer的原理与实现

pytorch张量数据索引切片与维度变换操作大全(非常全)

PyTorch学习笔记 7.TextCNN文本分类