谷歌JAX:下一代深度学习框架大战的前夜

Posted ScorpioDoctor

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了谷歌JAX:下一代深度学习框架大战的前夜相关的知识,希望对你有一定的参考价值。

机器学习工具的革命:下一代深度学习框架大战的前夜

本文来自 人工智能社区 http://studyai.com/article/a6b60b7f

1、介绍

让我们回到一个更简单的时代,当所有人在机器学习中谈论的都是支持向量机和增强树(Boosted Tree),而Andrew Ng将神经网络引入了一个在实践中可能永远不会用到的巧妙的帽子戏法。

那一年是2012年,基于计算机视觉的比赛ImageNet将再次被最新的核心方法组合赢得。当然,直到几位研究人员公布了AlexNet,它的错误率几乎是竞争对手的一半,我们现在通常称之为 “深度学习”

许多人认为AlexNet是十年来最重要的科学突破之一,它无疑有助于改变ML研究的格局。然而,不需要太多的时间就可以认识到,在幕后,它只是先前迭代改进的组合,许多改进可以追溯到90年代初。AlexNet的核心“仅仅”是一个改进的LeNet3,它具有更多的层、更好的权重初始化、激活函数和数据增强功能。

2、ML研究中的工具

那么是什么让AlexNet如此引人注目呢?我相信答案在于研究人员掌握的工具,使他们能够在GPU加速器上运行人工神经网络,这在当时是一个相对新颖的想法。事实上,Alex Krizhevsky的前同事回忆说,比赛前的许多会议都是由Alex描述他在CUDA怪癖和特写方面的进步组成的。

现在让我们回到2015,当ML研究文章提交开始全面爆发,包括(重新)出现许多现在有希望的方法,如生成对抗学习,深度强化学习,元学习,自我监督学习,联合学习,神经架构搜索,神经微分方程,神经图网络,等等。

有人可能会说,这只是人工智能炒作的自然结果。然而,我相信一个重要的因素是第二代通用ML框架的出现,比如TensorFlow和PyTorch,以及NVIDIA在 AI上的全面应用。以前存在的框架,如Caffe和Theano,很难使用,也很难扩展,这就减缓了新思想的研究和开发。

3、革命的必要性

TensorFlow和PyTorch无疑是一个积极的网络,团队努力改进库。最近,他们交付了TensorFlow 2.0,它有一个更直接的接口和eager 模式,Pythorch 1.0 提供了计算图的JIT编译以及对基于XLA的加速器(如TPU)的支持。然而,这些框架也开始达到它们的极限,迫使研究人员进入一些途径,同时关闭其他途径,就像他们的前辈一样。

AlphaStar和OpenAI等备受瞩目的DRL项目不仅利用了大规模的计算集群,还通过结合深度变换器、嵌套的递归网络、深度残差塔等,突破了深度学习体系结构组件的局限性。

Demis Hassabis 在接受泰晤士报采访时表示,DeepMind将专注于直接应用人工智能实现科学突破。我们已经可以通过他们最近发表的一些关于神经科学和蛋白质折叠的自然文章,看到这一方向的转变。即使对这些出版物作一个简短的浏览,也足以看出这些项目在工程方面需要一些非常规的方法。

在NeurIPS 2019,概率规划和贝叶斯推理是热门话题,尤其是不确定性估计和因果推理。领先的人工智能研究人员展示了他们对人工智能未来的展望。值得注意的是,Yoshua Bengio描述了向 system 2 深度学习的过渡,包括分布外泛化、稀疏图网络和因果推理。

总而言之,下一代ML工具的一些要求是:

细粒度控制流使用
非标准优化循环
作为一等公民的高阶微分
一等公民的概率规划
在一个模型中支持多个异构加速器
从单机到大型集群的无缝可扩展性

理想情况下,这些工具还应该维护一个干净、直接和可扩展的API,使科学家能够快速研究和开发他们的想法。

4、下一代工具

好消息是,今天已经有许多候选工具,是为了响应科学计算的需要而出现的。从实验项目如zygate.jl到专门语言,如halide1和DiffTaichi。有趣的是,许多项目从自动微分社区的研究人员所做的基础工作中获得灵感,这些工作与ML并行发展。

他们中的许多人在最近的NeurIPS 2019项目转型研讨会上亮相。我最兴奋的两个是 S4TFJAX 。他们都致力于把可微编程变成工具链的一个组成部分,但他们都以自己的方式,几乎是正交的。

4.1、S4TF

顾名思义,S4TF将TensorFlow ML框架与Swift编程语言紧密集成。对该项目的信任投票是由Chris Lattner领导的,他曾撰写LLVM、Clang和Swift。

Swift是一种编译编程语言,它的主要卖点之一是强大的静态类型系统。最后一部分简单地说,SWIFT包含了在Python中使用代码的验证和转换,例如在C++中的易用性。

let a: Int = 1
let b = 2
let c = "3"

print(a + b)         // 3
print(b + c)         // compilation (!) error
print(String(b) + c) // 23

Swift特性使S4TF团队能够通过在编译过程中对使用有效算法执行的计算图进行分析、验证和优化,来满足下一代列表中的许多要求。

最重要的是,自动微分的处理被卸载到编译器中。

这个项目是一项重大的工程,在投产前还有一些路要走。然而,这是一个让工程师和研究人员都去尝试的好时机,并且有可能对它的发展做出贡献。关于S4TF的工作已经在编程语言和自动微分理论的交叉点上产生了有趣的科学进展。

关于S4TF,有一件事对我来说特别突出,那就是他们的社区推广方式。例如,核心开发人员每周都会举行设计会议,任何有兴趣参加甚至参与的人都可以参加。

当然,TensorFlow本身在这个例子中得到了很好的支持。

import TensorFlow

struct Model: Layer 
    var conv = Conv2D<Float>(filterShape: (5, 5, 6, 16), activation: relu)
    var pool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
    var flatten = Flatten<Float>()
    var dense = Dense<Float>(inputSize: 16 * 5 * 5, outputSize: 100, activation: relu)
    var logits = Dense<Float>(inputSize: 100, outputSize: 10, activation: identity)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> 
        return input.sequenced(through: conv, pool, flatten, dense, logits)
    


var model = Model()
let optimizer = RMSProp(for: model, learningRate: 3e-4, decay: 1e-6)

for batch in CIFAR10().trainDataset.batched(128) 
  let (loss, gradients) = valueWithGradient(at: model)  model in
    softmaxCrossEntropy(logits: model(batch.data), labels: batch.label)
  
  print(loss)
  optimizer.update(&model, along: gradients)

另一方面,如果一个关键特性被证明难以实现,那么对整个管道有深入的了解就特别有价值。例如,MLIR编译器框架是S4TF工作的直接结果。

虽然可微编程是核心目标,但S4TF远不止支持各种下一代ML工具(如调试器)的基础设施。例如,假设一个IDE警告用户自定义模型计算总是导致零梯度,甚至不执行它。

Python有一个围绕科学计算构建的令人难以置信的社区,S4TF团队已经明确地花了时间通过互操作性来接受它。

import Python // All that is necessary to enable the interop.

let np = Python.import("numpy") // Can import any Python module.
let plt = Python.import("matplotlib.pyplot") 

let x = np.arange(0, 10, 0.01)
plt.plot(x, np.sin(x)) // Can use the modules as if inside Python.
plt.show() // Will show the sin plot, just as you would expect.

这个项目是一项重大的工程,在投产前还有一些路要走。然而,这是一个让工程师和研究人员都去尝试的好时机,并且有可能对它的发展做出贡献。关于S4TF的工作已经在编程语言和自动微分理论的交叉点上产生了有趣的科学进展。

关于S4TF,有一件事对我来说特别突出,那就是他们的社区推广方式。例如,核心开发人员每周都会举行设计会议,任何有兴趣参加甚至参与的人都可以参加。

要了解更多有关Swift for TensorFlow的信息,以下是一些有用的资源:

Fast.ai’s Lessons 13 and 14
Design Doc: Why Swift For TensorFlow?
Model Training Tutorial
Pre-Built Google Colab Notebook
Swift Models for Popular Architectures in DL and DRL

4.2、 JAX

JAX是一个函数转换的集合,例如实时编译和自动微分,它是用一个API在XLA上实现的瘦包装器,API本质上是NumPy和SciPy的替代品。事实上,开始使用JAX的一种方法是将其视为一个加速器支持的NumPy。

import jax.numpy as np

# Will be seamlessly executed on an accelerator such as GPU/TPU.
x, w, b = np.ones((3, 1000, 1000))
y = np.dot(w, x) + b

当然,实际上,JAX远不止这些。对许多人来说,这个项目似乎是凭空出现的,但事实是,它是跨越三个项目的五年多研究的演变。值得注意的是,JAX是从Autograd(一种对本地程序代码的AD的研究)发展而来的,它概括了支持任意转换的核心思想

def f(x):
  return np.where(x > 0, x, x / (1 + np.exp(-x)))

# Note: same singular style for the API entry points.
jit_f = jax.jit(f) # Will be 10-100x faster, depending on the accelerator.
grad_f = jax.grad(f) # Will work as expected, handling both branches. 

除了上面讨论的gradjit之外,还有两个更优秀的JAX转换示例,帮助用户通过批处理维度的自动矢量化(vmap)或跨多个设备(pmap)批处理数据。

a = np.ones((100, 300))

def g(vec):
  return np.dot(a, vec)

# Suppose `z` is a batch of 10 samples of 1 x 300 vectors.
z = np.ones((10, 300))

g(z) # Will not work due to (batch) dimension mismatch (100x300 x 10x300).

vec_g = jax.vmap(g)
vec_g(z) # Will work, efficiently propagating through batch dimension.

# Manual solution requires "playing" with matrix transpositions.
np.dot(a, z.T)

这些特性一开始可能看起来令人困惑,但经过一些实践,它们变成了研究人员工具箱中不可替代的一部分。它们甚至激发了最近TensorFlow和PyTorch中类似功能的开发。

目前,JAX作者在开发新特性时似乎坚持自己的核心能力。当然,合理的方法也是其主要缺点之一:缺乏内置的神经网络组件,除了证明Stax的概念。

添加高级功能是终端用户可能参与和贡献的东西,并且赋予JAX的坚实基础,任务可能比看起来更容易。例如,现在有两个构建在JAX之上的“竞争”库,它们都是由Google研究人员开发的,使用不同的方法:TraxFlax

# Trax approach is functional.
# Note: params are stored outside and `forward` is "pure".

import jax.numpy as np
from trax.layers import base

class Linear(base.Layer):
  def __init__(self, num_units, init_fn):
    super().__init__()
    self.num_units = num_units
    self.init_fn = init_fn

  def forward(self, x, w):
    return np.dot(x, w)

  def new_weights(self, input_signature):
    w = self.init_fn((input_signature.shape, self._num_units))
    return w
# Flax approach is object-oriented, closer to PyTorch style.

import jax.numpy as np
from flax import nn

class Linear(nn.Module):
  def apply(self, x, num_units, init_fn):
    W = self.param('W', (x.shape[-1], num_units), init_fn)
    return np.dot(x, W)

尽管有些人可能更喜欢由核心开发人员认可的单一方式,但方法的多样性很好地表明这项技术是可靠的。

在JAX特性特别突出的领域,也有一些研究方向。例如,在元学习中,训练元学习器的一种常见方法是计算输入的梯度。为了有效地解决这一问题,需要一种计算梯度的替代方法——前向模式自动微分法,这在JAX中是现成的,但在其他库中要么是不存在的,要么是实验性的特性。

JAX可能比它的S4TF计数器部件更精良,更易于生产,Google研究的一些最新发展依赖于它,比如 Reformer——一种内存高效的转换器模型,能够处理一百万字的上下文窗口,同时安装在消费者的GPU上,和神经切线-一个无限宽复杂神经网络库。

该库还被更广泛的科学计算社区所接受,用于分子动力学、概率规划和约束优化等领域的工作。

要开始使用JAX并进一步阅读,请查看以下内容:

Talk: Overview by Skye Wanderman-Milne, a core developer (starts at 44:26)
Notebook: Quickstart, going over fundamental features
Notebook: Cloud TPU Playground
Blog: You don’t know JAX
Blog: Massively parallel MCMC with JAX
Blog: Differentiable Path Tracing on the GPU/TPU

5、结论

ML研究已经开始触及我们目前可以使用的工具的极限,但是一些新的、令人兴奋的候选者即将到来,比如JAX和S4TF。如果你觉得自己更像一个工程师,而不是一个研究人员,并想知道是否有一个地方值得攻入,答案是明确的:现在是攻入它的最佳时机。此外,您还有机会参与到下一代ML工具的底层!

请注意,这并不意味着TensorFlow或PyTorch会在不久的将来就马上寿终正寝。在这些成熟的、经过战斗测试的库中仍然有很多价值。毕竟,JAX和S4TF都有TensorFlow的一部分。但是,如果你即将开始一个新的研究项目,或者如果你觉得你是围绕着库的限制而不是你的想法工作,那么也许给他们一个尝试!

以上是关于谷歌JAX:下一代深度学习框架大战的前夜的主要内容,如果未能解决你的问题,请参考以下文章

深度学习框架大战:谁将夺取“深度学习工业标准”荣耀?

JAX的深度学习和科学计算

TensorFlow败给PyTorch,谷歌:未来就靠你了,JAX

第535期机器学习日报(2016-03-06) 深度学习框架大战正在进行,谁将夺取“深度学习工业标准”的荣耀?

20210219期AI简报嵌入式机器学习(TinyML)实战教程谷歌开源计算框架JAX...

干货集锦深度学习框架专题