它来了,有了JAX 轻松加速机器学习

时间:2023-02-07 19:03:18

JAX是一个为高性能数值计算设计的Python库,特别是机器学习研究。它通过使用GPU来加速Python和NumPy代码。

JAX在机器学习领域崭露头角,其野心是使机器学习变得简单而高效。虽然,JAX仍然是谷歌和Deepmind的研究项目,还不是谷歌的官方产品,但已经被内部广泛使用,并被外部的ML研究人员采用。我们想提供一个关于JAX的介绍,如何安装JAX,以及它的优势和能力。

它来了,有了JAX 轻松加速机器学习

什么是机器学习的JAX?

JAX是一个为高性能数值计算设计的Python库,特别是机器学习研究。它的数值函数的API是基于NumPy的,这是一个用于科学计算的函数集合。JAX专注于通过使用XLA在GPU上编译NumPy函数来加速机器学习过程,并使用autograd来区分Python和NumPy函数,以及基于梯度的优化。JAX能够通过循环、分支、递归和闭合进行分化,并利用GPU加速轻松地获取导数的导数。JAX还支持反向传播和正向模式的微分。

当使用GPU运行你的代码时,JAX提供了卓越的性能,还有一个及时编译(JIT)选项,可以轻松加快大型项目的速度,我们将在本文后面深入探讨这个问题。

把JAX看作是一个Python库,它通过函数转换来修改NumPy和Python代码,以实现加速的机器学习。一般来说,只要你打算用GPU进行训练,计算梯度(autograd),或者使用JIT代码编译,都应该使用JAX。

为什么使用JAX?

除了与普通的CPU一起工作外,JAX的主要功能是能够与不同的处理单元(如GPU)一起完全发挥作用。这使得JAX与类似的软件包相比具有很大的优势,因为在涉及到图像和矢量处理时,使用GPU并行化可以使性能比CPU更快。

这一点极为重要,因为在使用NumPy库时,用户可以建立特殊大小的矩阵,使GPU在处理这类数据格式时更有时间效率。

这个时间差使得JAX库的速度和性能通过几个关键的实现超过了NumPy本身100倍以上。

矢量化--将多个数据作为单一指令处理,为线性代数计算和机器学习提供了巨大的速度。

代码并行化--在单个处理器上运行的串行代码,并将其分发出去的过程。这里首选GPU,因为它们有许多专门用于计算的处理器。

自动微分--非常简单和直接的微分,可以多次串联,轻松地评估高阶导数。

如何安装JAX

要安装只有CPU版本的JAX,这对于在笔记本电脑上进行本地开发可能是有用的,你可以运行

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

在Linux上,通常需要先将pip更新到支持manylinux2014轮的版本。

pip安装GPU (CUDA)

要安装支持CPU和NVIDIA GPU的JAX,你必须先安装CUDA和CuDNN,如果它们还没有被安装。与许多其他流行的深度学习系统不同,JAX并没有将CUDA或CuDNN作为pip包的一部分来捆绑。

JAX只为Linux提供预建的兼容CUDA的*,包括CUDA 11.1或更新版本,以及CuDNN 8.0.5或更新版本。其他操作系统、CUDA和CuDNN的组合也是可能的,但需要从源代码中构建。

需要CUDA 11.1或更新版本

如果你从源码构建,你可能会使用更早的CUDA版本,但是所有11.1以上的CUDA版本都有已知的错误,所以我们不会为旧的CUDA版本提供预构建的二进制文件。

预置*支持的cuDNN版本是:

cuDNN 8.2或更新版本。如果你的cuDNN安装得足够新,我们建议使用cuDNN 8.2*,因为它支持额外的功能。

cuDNN 8.0.5或更新版本。

您必须使用至少与您的CUDA工具箱对应的驱动版本一样新的NVIDIA驱动版本。例如,如果你安装了CUDA 11.4 update 4,如果在Linux上,你必须使用NVIDIA驱动470.82.01或更新版本。这是一个严格的要求,它的存在是因为JAX依赖于JIT-compiling代码;旧的驱动程序可能会导致失败。

如果你需要使用较新的CUDA工具包和较旧的驱动程序,例如在一个不能轻易更新NVIDIA驱动程序的集群上,你也许可以使用NVIDIA为此提供的CUDA向前兼容包。

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda]" https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

jaxlib的版本必须与你要使用的现有CUDA安装的版本相对应。你可以为jaxlib明确指定一个特定的CUDA和CuDNN版本。

pip install --upgrade pip
# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

你可以用命令找到你的CUDA版本:

nvcc --version

比较JAX和NumPy

由于JAX是一个增强的NumPy,它们的语法非常相似,使用户有能力在NumPy或JAX不执行的项目中交替使用这两种方法。这通常是在较小的项目中,加速的数量在节省的时间上是可以忽略不计的。然而,随着模型越来越大,你越应该考虑JAX。

它来了,有了JAX 轻松加速机器学习

使用JAX与NumPy进行两个矩阵的乘法运算

为了清楚地说明这两个库的速度差异,我们将使用这两个库将两个矩阵相乘,然后检查仅CPU和GPU的性能差异。我们还将检查由JIT编译器引起的性能提升。

为了继续学习本教程,请安装并导入JAX和NumPy库(来自前一步)。你可以在Kaggle或Google Colab等网站上测试你的代码。与任何库一样,你应该在代码的开头写上以下几行来导入JAX。

import jax.numpy as jnp
from jax import random

你也可以用类似的方式导入NumPy库:

import numpy as np

接下来,我们将使用CPU和GPU比较JAX和Numpy的性能,在Python中把两个矩阵相乘。对于这些基准测试,越低越好。

CPU上的NumPy

首先,我们将使用NumPy创建一个5,000乘以5,000的矩阵,并测试其速度方面的性能。

import numpy as np

size = 5000
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)

每循环785毫秒

在NumPy上运行的代码的单次循环,每次循环的时间约为750毫秒。

CPU上的JAX

现在让我们运行同样的代码,但这次是使用JAX库。

import jax.numpy as jnp

size = 5000
x = jnp.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

每个循环1.43秒

正如你所看到的,比较JAX和NumPy的纯CPU性能表明,NumPy是更快的选择。虽然JAX在普通的CPU上可能无法提供最好的性能,但它在GPU上确实提供了更好的性能。

使用GPU的JAX

现在,让我们尝试创建同样的5,000乘5,000的矩阵,这次使用JAX与GPU而不是普通CPU。

import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
size = 5000

x = random.normal(key, (size, size)).astype(jnp.float32)
%time x_jax = jax.device_put(x)
%time jnp.dot(x_jax, x_jax.T).block_until_ready()
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready()

每循环80.6毫秒

正如清楚显示的那样,当在GPU上而不是CPU上运行JAX时,我们实现了更好的时间,每循环约80ms(大约15倍的性能)。当使用更大的矩阵或时间尺度时,这将更容易看到。

及时编译(JIT)

使用jit命令,我们的代码将使用特定的XLA编译器进行编译,使我们的函数能够有效执行。

XLA是加速线性代数的简称,被JAX和Tensorflow等库用来在GPU上编译和运行代码,效率更高。因此,总结起来,XLA是一个特定的线性代数编译器,能够以更高的速度编译代码。

我们将使用selu_np函数测试我们的代码,该函数代表缩放指数线性单元,并检查NumPy在普通CPU上的不同时间表现,以及在GPU上用JIT运行JAX。

def selu_np(x, alpha=1.67, lmbda=1.05):
return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

def selu_jax(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

CPU上的NumPy

首先,我们将使用NumPy库创建一个大小为1,000,000的向量。

import numpy as np

x = np.random.normal(size=(1000000,)).astype(np.float32)
%timeit selu_np(x)

每循环8.3毫秒

GPU上的JAX与JIT

现在我们将在GPU上使用JAX和JIT来测试我们的代码。

import jax
import jax.numpy as jnp
from jax import random
from jax import grad, jit

key = random.PRNGKey(0)

def selu_np(x, alpha=1.67, lmbda=1.05):
return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

def selu_jax(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))

selu_jax_jit = jit(selu_jax)
%time x_jax = jax.device_put(x)
%time selu_jax_jit(x_jax).block_until_ready()
%timeit selu_jax_jit(x_jax).block_until_ready()

每个循环153微秒(每个循环0.153毫秒)

最后,当使用JIT编译器和GPU时,我们得到了比使用普通GPU更好的性能。正如你可以清楚地看到,差异是非常明显的,从NumPy到使用JIT的JAX,速度提高了近5000%,或者说是50倍!

把JAX看成是对NumPy的修改,以实现用GPU加速机器学习。由于NumPy只能在CPU上编译,如果你选择在GPU上执行代码,JAX就比NumPy快。作为一般规则,只要计划在GPU上使用NumPy或使用JIT代码编译,就应该使用JAX。


JAX的局限性:纯函数

JAX 转换和复杂化是为功能纯正的 Python 函数设计的。纯函数不能通过访问外部变量来改变程序的状态,也不能对诸如 print() 这样的输入/输出流函数产生副作用。

连续的运行会导致这些副作用不能按预期执行。如果你不小心,未被追踪的副作用可能会使你的预期计算的准确性受到影响。

使用谷歌的JAX

它来了,有了JAX 轻松加速机器学习

在这篇文章中,我们解释了JAX的功能以及它给NumPy带来的优势。我们介绍了如何安装JAX库以及它对机器学习的优势。

然后我们继续导入JAX和NumPy。此外,我们将JAX与NumPy(这是最著名的竞争库)进行了比较,并揭示了这两者之间的时间和性能差异,使用普通的CPU和GPU以及一些JIT测试,看到了速度的大幅提高。

如果你是一个高级机器/深度学习从业者,那么在你的武器库中添加一个像JAX这样的库,它的(GPU/TPU)加速器和它的高效JIT编译器肯定会让你的生活变得更加轻松。