谷歌JAX快速入门笔记详解和案例

Posted cui_yonghua

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了谷歌JAX快速入门笔记详解和案例相关的知识,希望对你有一定的参考价值。

一. 什么是JAX?

JAX最初由谷歌大脑团队的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人发起,借助 Autograd 的更新版本,并且结合了 XLA,可对 Python 程序与 NumPy 运算执行自动微分,支持循环、分支、递归、闭包函数求导,也可以求三阶导数;依赖于 XLA,JAX 可以在 GPU 和 TPU 上编译和运行 NumPy 程序;通过 grad,可以支持自动模式反向传播和正向传播,且二者可以任意组合成任何顺序。

JAX并非是一个深度学习的框架或者库,它的设计目标也并非是作为一个新的深度学习框架。

简单来说,JAX是一个包含可组合函数变换的数值计算库,只不过深度学习恰好是JAX能做的一项工作。

JAX处于函数变换(function transformations)和科学计算的交界处,所以也有能力训练神经网络模型,但不止于训练。

目前JAX在Github上已经斩获了超2万多颗star:

github地址:https://github.com/google/jax(截至目前,star数:20.3k)

官方文档:https://jax.readthedocs.io/en/latest/

JAX 是一个非常有前途的项目,并且用户一直在稳步增长。JAX 已经在深度学习、机器人 / 控制系统、贝叶斯方法和科学模拟等诸多领域得到了广泛应用。

二. 为什么应该使用JAX

JAX目前已经达到深度学习的最高水平。在当前开源的框架中,没有哪一个框架能在简洁、易用、速度这3个方面有两个能同时超过JAX。

  • 简洁:JAX的设计追求最少的封装,尽量避免重复造轮子。设计遵循tensor→variable(autograd)→module 3个由低到高的层次,分别代表高维数组(张量),自动求导(变量)和神经网络(层/模块),而且这3个抽象直接连接紧密,可以同时进行修改和操作。而tensorflow充斥着graph、operation、tensor、layer等全新的概念。JAX源码只有 tensorflow 的十分之一左右,更少的抽象、更直观的设计使得JAX的源码十分易于阅读。
  • 速度:在许多测评中,JAX的速度表现胜过TensorFlow和PyTorch等框架。
  • 易用:JAX是所有框架中面向对象设计的最优雅的一个,符合人们的思维,可以让用户专注于自己的想法,不需要考虑太多关于框架本身的束缚。

2.1 应该使用JAX的 6个原因:

  1. 加速NumPy。NumPy是用Python进行科学计算的基本软件包之一,但它只与CPU兼容。JAX提供了一个NumPy的实现(具有近乎相同的API),可以非常容易地在GPU和TPU上工作。对于许多用户来说,仅仅这一点就足以证明使用JAX的合理性。

  2. XLA,即加速线性代数(Accelerated Linear Algebra),是一个全程序优化编译器,专门为线性代数设计。JAX是建立在XLA之上的,大大提升了计算速度的上限。

  3. JIT。JAX允许用户使用XLA将函数转化为JIT(just in time)编译的版本。这意味用户可以通过给计算函数添加一个简单的函数装饰器来提高计算速度,可能是几个数量级的性能提升。

  4. 自动求导。JAX文档将JAX称为Autograd和XLA的结合体。自动求导的能力在科学计算的许多领域都是至关重要的,而JAX提供了几个强大的自动求导工具。

  5. 深度学习。虽然JAX本身不是一个深度学习框架,但它肯定为深度学习提供了一个更充分的基础。现在有许多建立在JAX之上的深度学习库,例如Flax、Haiku和Elegy。甚至有研究人员在PyTorch vs TensorFlow文章中强调JAX也是一个值得关注的「框架」,推荐其用于基于TPU的深度学习研究。JAX对Hessians的高效计算也与深度学习有关,因为它们使高阶优化技术更进一步。

  6. 通用可微分编程范式。虽然可以使用JAX来构建和训练深度学习模型,但它也为通用可微分编程提供了一个框架。这意味着JAX可以通过使用基于模型的机器学习方法来解决实际问题。

2.2 不该使用JAX的情况

虽然JAX有可能极大地提高你的程序的性能,但也有几种情况下,是不适合使用JAX的。

  1. JAX仍然被官方认为是一个实验性框架,而不是一个完全成熟的Google产品,所以如果你正在考虑转移到JAX,需要慎重考虑。
  2. 在使用JAX时,调试的时间成本会更高,并且有很多bug仍然未被发现。对于那些没有牢固掌握函数式编程的人来说,使用JAX可能不值得。在开始将JAX用于正式的项目之前,请确保了解使用JAX的常见陷阱。
  3. JAX没有针对CPU计算进行优化。鉴于JAX是以「加速优先」的方式开发的,因此每个操作的调度并没有完全优化。正因为如此,在某些情况下,NumPy实际上可能比JAX更快,特别是对于小程序来说。
  4. JAX与Windows不兼容。目前在Windows上没有对JAX的支持。

三. 安装和使用

python环境下安装:
pip3 install jax
pip3 install jaxlib

能成功打印如下代码即可成功安装。

import jax
print(jax.random.PRNGKey(17))  # [ 0 17]

注意:jax.numpy是CPU、GPU和TPU上的numpy,具有出色的自动差异化功能,可用于高性能机器学习研究。。由此特意进行了测试:

先试一下原生的numpy:

import numpy as np
import time

x = np.random.random([5000, 5000]).astype(np.float32)
try:
    st = time.time()
    y = np.matmul(x, x)
    print(time.time() - st)
    print(y)
except Exception as e:
    print(f"error: e")

执行结果如下:(耗时4.1秒)

再来试一下jax带的numpy:(耗时4.8秒)

import jax.numpy as np
from jax import random
import time

x = random.uniform(random.PRNGKey(0), [10000, 10000])
st = time.time()
try:
    y = np.matmul(x, x)
    print(time.time() - st)
    print(y)
except Exception as e:
    print(f"error: e")

执行结果如下:

实践是检验真理的唯一标准。(我用的MacBook Pro做的测试)

为什么JAX没有比numpy更快呢?

以上是关于谷歌JAX快速入门笔记详解和案例的主要内容,如果未能解决你的问题,请参考以下文章

谷歌JAX快速入门笔记详解和案例

谷歌JAX快速入门笔记详解和案例

Annotaion 注解 详解案例

Vue-00-笔记

带有安全过滤器和依赖注入的 jax-rs 1.1,如何实现这一点?

『JAX中文文档』JAX快速入门