JAX 术语表
数组
JAX 的 numpy.ndarray
的类比。见 jax.Array
。
缩写Central Processing Unit,CPU 是大多数计算机中可用的标准计算架构。JAX 可以在 CPU 上运行计算,但通常在 GPU 和 TPU 上可以实现更好的性能。
设备
用于指代 JAX 用于执行计算的 CPU、GPU 或 TPU 的通用名称。
forward-mode autodiff
见 JVP
函数式编程
一种编程范式,程序通过应用和组合纯函数定义。JAX 设计用于函数式程序。
GPU
缩写Graphical Processing Unit,GPU 最初专门用于图像渲染相关操作,但现在更通用。JAX 能够针对 GPU 进行快速数组操作(参见 CPU 和 TPU)。
jaxpr
缩写JAX Expression,jaxpr 是由 JAX 生成的计算的中间表示形式,转发到 XLA 进行编译和执行。详见 Understanding Jaxprs 以获取更多讨论和示例。
JIT
缩写Just In Time 编译,JIT 在 JAX 中通常指将数组操作编译为 XLA,通常使用 jax.jit()
完成。
JVP
缩写Jacobian Vector Product,有时也称为forward-mode 自动微分。有关详细信息,请参阅 Jacobian-Vector products (JVPs, aka forward-mode autodiff)。在 JAX 中,JVP 是通过 jax.jvp()
实现的转换。另见 VJP。
primitive
primitive 是 JAX 程序中使用的基本计算单位。jax.lax
中的大多数函数代表单个原语。在 jaxpr 中表示计算时,jaxpr 中的每个操作都是一个原语。
纯函数
纯函数是仅基于其输入生成输出且没有副作用的函数。JAX 的转换模型设计用于处理纯函数。参见 functional programming。
pytree
pytree 是一个抽象,允许 JAX 以统一的方式处理元组、列表、字典和其他更一般的包含数组值的容器。请参阅 Working with pytrees 以获取更详细的讨论。
reverse-mode autodiff
见 VJP。
SPMD
缩写Single Program Multi Data,指的是一种并行计算技术,即在不同设备(例如几个 TPU)上并行运行相同计算(例如神经网络的前向传播)对不同输入数据(例如批处理中的不同输入)的计算。jax.pmap()
是实现 SPMD 并行性的 JAX 转换。
static
在 JIT 编译中,未被追踪的值(参见 Tracer)。有时也指静态值的编译时计算。
TPU
张量处理单元 的缩写,TPUs 是专门为深度学习应用中的 N 维张量快速运算而设计的芯片。JAX 能够针对 TPUs 进行快速数组操作(另见 CPU 和 GPU)。
追踪器
一个用作 JAX 数组替身的对象,以确定 Python 函数执行的操作序列。在内部,JAX 通过 jax.core.Tracer
类实现此功能。
转换
高阶函数:即接受函数作为输入并输出转换后函数的函数。在 JAX 中的示例包括 jax.jit()
、jax.vmap()
和 jax.grad()
。
VJP
向量雅可比积,有时也称为反向模式自动微分。有关详细信息,请参阅向量雅可比积(VJPs,又称反向模式自动微分)。在 JAX 中,VJP 是通过 jax.vjp()
实现的转换。还请参阅 JVP。
XLA
加速线性代数 的缩写,XLA 是一个专用于线性代数操作的编译器,是 JIT 编译 JAX 代码的主要后端。请参阅 www.tensorflow.org/xla/
。
弱类型
JAX 数据类型,其类型提升语义与 Python 标量相同;请参阅 JAX 中的弱类型值。