开始入门
安装 JAX
使用 JAX 需要安装两个包:jax
是纯 Python 的跨平台库,jaxlib
包含编译的二进制文件,对于不同的操作系统和加速器需要不同的构建。
TL;DR 对于大多数用户来说,典型的 JAX 安装可能如下所示:
- 仅限 CPU(Linux/macOS/Windows)
pip install -U jax
- GPU(NVIDIA,CUDA 12)
pip install -U "jax[cuda12]"
- TPU(Google Cloud TPU VM)
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
支持的平台
下表显示了所有支持的平台和安装选项。检查您的设置是否受支持;如果显示“是”或“实验性”,请单击相应链接以了解更详细的 JAX 安装方法。
Linux,x86_64 | Linux,aarch64 | macOS,Intel x86_64,AMD GPU | macOS,Apple Silicon,基于 ARM | Windows,x86_64 | Windows WSL2,x86_64 | |
CPU | 是 | 是 | 是 | 是 | 是 | 是 |
NVIDIA GPU | 是 | 是 | 否 | 不适用 | 否 | 实验性 |
Google Cloud TPU | 是 | 不适用 | 不适用 | 不适用 | 不适用 | 不适用 |
AMD GPU | 实验性 | 否 | 否 | 不适用 | 否 | 否 |
| Apple GPU | 不适用 | 否 | 实验性 | 实验性 | 不适用 | 不适用 | ## CPU
pip 安装:CPU
目前,JAX 团队为以下操作系统和架构发布 jaxlib
轮子:
- Linux,x86_64
- Linux, aarch64
- macOS,Intel
- macOS,基于 Apple ARM
- Windows,x86_64(实验性)
要安装仅 CPU 版本的 JAX,可能对于在笔记本电脑上进行本地开发非常有用,您可以运行:
pip install --upgrade pip pip install --upgrade jax
在 Windows 上,如果尚未安装 Microsoft Visual Studio 2019 Redistributable,您可能还需要安装它。
其他操作系统和架构需要从源代码构建。在其他操作系统和架构上尝试 pip 安装可能导致 jaxlib
未能与 jax
一起安装(虽然 jax
可能成功安装,但在运行时可能会失败)。 ## NVIDIA GPU
JAX 支持具有 SM 版本 5.2(Maxwell)或更新版本的 NVIDIA GPU。请注意,由于 NVIDIA 在其软件中停止了对 Kepler 系列 GPU 的支持,JAX 不再支持 Kepler 系列 GPU。
您必须先安装 NVIDIA 驱动程序。建议您安装 NVIDIA 提供的最新驱动程序,但驱动版本必须 >= 525.60.13 才能在 Linux 上运行 CUDA 12。
如果您需要在较老的驱动程序上使用更新的 CUDA 工具包,例如在无法轻松更新 NVIDIA 驱动程序的集群上,您可以使用 NVIDIA 专门为此目的提供的 CUDA 向前兼容包。
pip 安装:NVIDIA GPU(通过 pip 安装,更加简便)
有两种安装 JAX 并支持 NVIDIA GPU 的方式:
- 使用从 pip 轮子安装的 NVIDIA CUDA 和 cuDNN
- 使用自行安装的 CUDA/cuDNN
JAX 团队强烈建议使用 pip wheel 安装 CUDA 和 cuDNN,因为这样更加简单!
NVIDIA 仅为 x86_64 和 aarch64 平台发布了 CUDA pip 包;在其他平台上,您必须使用本地安装的 CUDA。
pip install --upgrade pip # NVIDIA CUDA 12 installation # Note: wheels only available on linux. pip install --upgrade "jax[cuda12]"
如果 JAX 检测到错误版本的 NVIDIA CUDA 库,您需要检查以下几点:
- 请确保未设置
LD_LIBRARY_PATH
,因为LD_LIBRARY_PATH
可能会覆盖 NVIDIA CUDA 库。 - 确保安装的 NVIDIA CUDA 库与 JAX 请求的库相符。重新运行上述安装命令应该可以解决问题。
pip 安装:NVIDIA GPU(本地安装的 CUDA,更为复杂)
如果您想使用预安装的 NVIDIA CUDA 副本,您必须首先安装 NVIDIA 的 CUDA和 cuDNN。
JAX 仅为 Linux x86_64 和 Linux aarch64 提供预编译的 CUDA 兼容 wheel。其他操作系统和架构的组合也可能存在,但需要从源代码构建(请参考构建指南以了解更多信息)。
您应该使用至少与您的NVIDIA CUDA toolkit 对应的驱动版本相同的 NVIDIA 驱动程序版本。例如,在无法轻易更新 NVIDIA 驱动程序的集群上需要使用更新的 CUDA 工具包,您可以使用 NVIDIA 为此目的提供的CUDA 向前兼容包。
JAX 目前提供一种 CUDA wheel 变体:
Built with | Compatible with |
CUDA 12.3 | CUDA >=12.1 |
CUDNN 9.0 | CUDNN >=9.0, <10.0 |
NCCL 2.19 | NCCL >=2.18 |
JAX 检查您的库的版本,如果版本不够新,则会报错。设置 JAX_SKIP_CUDA_CONSTRAINTS_CHECK
环境变量将禁用此检查,但使用较旧版本的 CUDA 可能会导致错误或不正确的结果。
NCCL 是一个可选依赖项,仅在执行多 GPU 计算时才需要。
安装方法如下:
pip install --upgrade pip # Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer. # Note: wheels only available on linux. pip install --upgrade "jax[cuda12_local]"
这些 pip
安装在 Windows 上无法工作,并可能静默失败;请参考上表。
您可以使用以下命令查找您的 CUDA 版本:
nvcc --version
JAX 使用 LD_LIBRARY_PATH
查找 CUDA 库,并使用 PATH
查找二进制文件(ptxas
、nvlink
)。请确保这些路径指向正确的 CUDA 安装位置。
如果在使用预编译的 wheel 时遇到任何错误或问题,请在GitHub 问题跟踪器上告知 JAX 团队。
NVIDIA GPU Docker 容器
NVIDIA 提供了JAX 工具箱容器,这些是 bleeding edge 容器,包含 jax 的夜间版本和一些模型/框架。 ## Google Cloud TPU
pip 安装:Google Cloud TPU
JAX 为 Google Cloud TPU 提供预构建的安装包。要在云 TPU VM 中安装 JAX 及相应版本的 jaxlib
和 libtpu
,您可以运行以下命令:
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
对于 Colab 的用户(https://colab.research.google.com/),请确保您使用的是 TPU v2 而不是已过时的旧 TPU 运行时。## Apple Silicon GPU(基于 ARM 的)
pip 安装:Apple 基于 ARM 的 Silicon GPU
Apple 为基于 ARM 的 GPU 硬件提供了一个实验性的 Metal 插件。详情请参阅 Apple 的 JAX on Metal 文档。
注意: Metal 插件存在一些注意事项:
- Metal 插件是新的实验性质,并存在一些已知问题,请在 JAX 问题跟踪器上报告任何问题。
- 当前的 Metal 插件需要非常特定版本的
jax
和jaxlib
。随着插件 API 的成熟,此限制将逐步放宽。## AMD GPU
JAX 具有实验性的 ROCm 支持。有两种安装 JAX 的方法:
- 使用 AMD 的 Docker 容器;或者
- 从源代码构建(参见从源代码构建 —— 一个名为 Additional notes for building a ROCM
jaxlib
for AMD GPUs 的部分)。
Conda(社区支持)
Conda 安装
存在一个社区支持的 jax
的 Conda 构建。要使用 conda
安装它,只需运行:
conda install jax -c conda-forge
要在带有 NVIDIA GPU 的机器上安装它,请运行:
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
请注意,由 conda-forge
分发的 cudatoolkit
缺少 JAX 所需的 ptxas
。因此,您必须从 nvidia
渠道安装 cuda-nvcc
包,或者在您的机器上单独安装 CUDA,以便 ptxas
在您的路径中可用。上述渠道顺序很重要(conda-forge
在 nvidia
之前)。
如果您希望覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的机器上安装 CUDA 版本,请按照 conda-forge
网站上“技巧和技巧”部分的说明操作。
前往 conda-forge
的 jaxlib 和 jax 存储库获取更多详细信息。
JAX 夜间安装
夜间版本反映了它们构建时主 JAX 存储库的状态,并且可能无法通过完整的测试套件。
- 仅限 CPU:
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
- Google Cloud TPU:
pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
- NVIDIA GPU(CUDA 12):
pip install -U --pre jax[cuda12] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
- NVIDIA GPU(CUDA 12)遗留:
用于历史 nightly 版本的单片 CUDA jaxlibs。您很可能不需要此选项;不会再构建更多的单片 CUDA jaxlibs,并且现有的将在 2024 年 9 月到期。请使用上面的“CUDA 12”选项。
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
从源代码构建 JAX
参考从源代码构建。
安装旧版本的 jaxlib
wheels
由于 Python 软件包索引上的存储限制,JAX 团队定期从 http://pypi.org/project/jax 的发布中删除旧的jaxlib
安装包。但是您仍然可以通过这里的 URL 直接安装它们。例如:
# Install jaxlib on CPU via the wheel archive pip install jax[cpu]==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html # Install the jaxlib 0.3.25 CPU wheel directly pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html
对于特定的旧 GPU 安装包,请确保使用jax_cuda_releases.html
的 URL;例如
pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
快速入门
JAX 是一个面向数组的数值计算库(à la NumPy),具有自动微分和 JIT 编译功能,以支持高性能的机器学习研究。
本文档提供了 JAX 主要功能的快速概述,让您可以快速开始使用 JAX:
- JAX 提供了一个统一的类似于 NumPy 的接口,用于在 CPU、GPU 或 TPU 上运行的计算,在本地或分布式设置中。
- JAX 通过 Open XLA 内置了即时编译(JIT)功能,这是一个开源的机器学习编译器生态系统。
- JAX 函数支持通过其自动微分转换有效地评估梯度。
- JAX 函数可以自动向量化,以有效地将它们映射到表示输入批次的数组上。
安装
可以直接从 Python Package Index 安装 JAX 用于 Linux、Windows 和 macOS 上的 CPU:
pip install jax
或者,对于 NVIDIA GPU:
pip install -U "jax[cuda12]"
如需更详细的特定平台安装信息,请查看安装 JAX。
JAX 就像 NumPy 一样
大多数 JAX 的使用是通过熟悉的 jax.numpy
API 进行的,通常在 jnp
别名下导入:
import jax.numpy as jnp
通过这个导入,您可以立即像使用典型的 NumPy 程序一样使用 JAX,包括使用 NumPy 风格的数组创建函数、Python 函数和操作符,以及数组属性和方法:
def selu(x, alpha=1.67, lmbda=1.05): return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) x = jnp.arange(5.0) print(selu(x))
[0\. 1.05 2.1 3.1499999 4.2 ]
一旦您开始深入研究,您会发现 JAX 数组和 NumPy 数组之间存在一些差异;这些差异在 🔪 JAX - The Sharp Bits 🔪 中进行了探讨。
使用jax.jit()
进行即时编译
JAX 可以在 GPU 或 TPU 上透明运行(如果没有,则退回到 CPU)。然而,在上述示例中,JAX 是一次将核心分派到芯片上的操作。如果我们有一系列操作,我们可以使用 jax.jit()
函数将这些操作一起编译为 XLA。
我们可以使用 IPython 的 %timeit
快速测试我们的 selu
函数,使用 block_until_ready()
来考虑 JAX 的动态分派(请参阅异步分派):
from jax import random key = random.key(1701) x = random.normal(key, (1_000_000,)) %timeit selu(x).block_until_ready()
2.84 ms ± 9.23 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(请注意,我们已经使用 jax.random
生成了一些随机数;有关如何在 JAX 中生成随机数的详细信息,请查看伪随机数)。
我们可以使用 jax.jit()
转换来加速此函数的执行,该转换将在首次调用 selu
时进行 JIT 编译,并在此后进行缓存。
from jax import jit selu_jit = jit(selu) _ = selu_jit(x) # compiles on first call %timeit selu_jit(x).block_until_ready()
844 μs ± 2.73 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
上述时间表示在 CPU 上执行,但同样的代码可以在 GPU 或 TPU 上运行,通常会有更大的加速效果。
欲了解更多关于 JAX 中 JIT 编译的信息,请查看即时编译。
使用 jax.grad()
计算导数
除了通过 JIT 编译转换函数外,JAX 还提供其他转换功能。其中一种转换是 jax.grad()
,它执行自动微分 (autodiff):
from jax import grad def sum_logistic(x): return jnp.sum(1.0 / (1.0 + jnp.exp(-x))) x_small = jnp.arange(3.) derivative_fn = grad(sum_logistic) print(derivative_fn(x_small))
[0.25 0.19661197 0.10499357]
让我们用有限差分来验证我们的结果是否正确。
def first_finite_differences(f, x, eps=1E-3): return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in jnp.eye(len(x))]) print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1965761 0.10502338]
grad()
和 jit()
转换可以任意组合并混合使用。在上面的示例中,我们对 sum_logistic
进行了 JIT 编译,然后取了它的导数。我们可以进一步进行:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.0353256
除了标量值函数外,jax.jacobian()
转换还可用于计算向量值函数的完整雅可比矩阵:
from jax import jacobian print(jacobian(jnp.exp)(x_small))
[[1\. 0\. 0\. ] [0\. 2.7182817 0\. ] [0\. 0\. 7.389056 ]]
对于更高级的自动微分操作,您可以使用 jax.vjp()
来进行反向模式向量-雅可比积分,以及使用 jax.jvp()
和 jax.linearize()
进行正向模式雅可比-向量积分。这两者可以任意组合,也可以与其他 JAX 转换组合使用。例如,jax.jvp()
和 jax.vjp()
用于定义正向模式 jax.jacfwd()
和反向模式 jax.jacrev()
,用于计算正向和反向模式下的雅可比矩阵。以下是组合它们以有效计算完整 Hessian 矩阵的一种方法:
from jax import jacfwd, jacrev def hessian(fun): return jit(jacfwd(jacrev(fun))) print(hessian(sum_logistic)(x_small))
[[-0\. -0\. -0\. ] [-0\. -0.09085776 -0\. ] [-0\. -0\. -0.07996249]]
这种组合在实践中产生了高效的代码;这基本上是 JAX 内置的 jax.hessian()
函数的实现方式。
想了解更多关于 JAX 中的自动微分,请查看自动微分。
使用 jax.vmap()
进行自动向量化
另一个有用的转换是 vmap()
,即向量化映射。它具有沿数组轴映射函数的熟悉语义,但与显式循环函数调用不同,它将函数转换为本地向量化版本,以获得更好的性能。与 jit()
组合时,它可以与手动重写函数以处理额外批处理维度的性能相媲美。
我们将处理一个简单的示例,并使用 vmap()
将矩阵-向量乘法提升为矩阵-矩阵乘法。虽然在这种特定情况下手动完成这一点很容易,但相同的技术也适用于更复杂的函数。
key1, key2 = random.split(key) mat = random.normal(key1, (150, 100)) batched_x = random.normal(key2, (10, 100)) def apply_matrix(x): return jnp.dot(mat, x)
apply_matrix
函数将一个向量映射到另一个向量,但我们可能希望将其逐行应用于矩阵。在 Python 中,我们可以通过循环遍历批处理维度来实现这一点,但通常导致性能不佳。
def naively_batched_apply_matrix(v_batched): return jnp.stack([apply_matrix(v) for v in v_batched]) print('Naively batched') %timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched 962 μs ± 1.54 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
熟悉 jnp.dot
函数的程序员可能会意识到,可以重写 apply_matrix
来避免显式循环,利用 jnp.dot
的内置批处理语义:
import numpy as np @jit def batched_apply_matrix(batched_x): return jnp.dot(batched_x, mat.T) np.testing.assert_allclose(naively_batched_apply_matrix(batched_x), batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4) print('Manually batched') %timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched 14.3 μs ± 28.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
然而,随着函数变得更加复杂,这种手动批处理变得更加困难且容易出错。vmap()
转换旨在自动将函数转换为支持批处理的版本:
from jax import vmap @jit def vmap_batched_apply_matrix(batched_x): return vmap(apply_matrix)(batched_x) np.testing.assert_allclose(naively_batched_apply_matrix(batched_x), vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4) print('Auto-vectorized with vmap') %timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap 21.7 μs ± 98.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
正如您所预期的那样,vmap()
可以与 jit()
、grad()
和任何其他 JAX 转换任意组合。
想了解更多关于 JAX 中的自动向量化,请查看自动向量化。
这只是 JAX 能做的一小部分。我们非常期待看到你用它做些什么!
JAX 中文文档(一)(2)https://developer.aliyun.com/article/1559830