JAX 中文文档(十五)(3)https://developer.aliyun.com/article/1559769
jax.experimental.maps 模块
API
xmap (fun, in_axes, out_axes, *[, …]) |
为使用命名数组轴的程序分配位置签名。 |
jax.experimental.pjit 模块
API
jax.experimental.pjit.pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)
使fun
编译并自动跨多设备分区。
注意:此函数现在等同于 jax.jit,请改用其代替。返回的函数语义与fun
相同,但编译为在多个设备(例如多个 GPU 或多个 TPU 核心)上并行运行的 XLA 计算。如果fun
的 jitted 版本无法适应单个设备的内存,或者为了通过在多个设备上并行运行每个操作来加速fun
,这将非常有用。
设备上的分区自动基于in_shardings
中指定的输入分区传播以及out_shardings
中指定的输出分区进行。这两个参数中指定的资源必须引用由jax.sharding.Mesh()
上下文管理器定义的网格轴。请注意,pjit()
应用时的网格定义将被忽略,并且返回的函数将使用每个调用站点可用的网格定义。
未经正确分区的pjit()
函数输入将自动跨设备分区。在某些情况下,确保输入已经正确预分区可能会提高性能。例如,如果将一个pjit()
函数的输出传递给另一个pjit()
函数(或者在循环中使用同一个pjit()
函数),请确保相关的out_shardings
与相应的in_shardings
匹配。
注意
多进程平台: 在诸如 TPU pods 的多进程平台上,pjit()
可用于跨所有可用设备和进程运行计算。为实现此目的,pjit()
设计为用于 SPMD Python 程序,其中每个进程运行相同的 Python 代码,以便所有进程按相同顺序运行相同的pjit()
函数。
在此配置中运行时,网格应包含跨所有进程的设备。所有输入参数必须具有全局形状。fun
仍将在网格中的所有设备上执行,包括来自其他进程的设备,并且将以全局视图处理跨多个进程展布的数据作为单个数组。
SPMD 模型还要求所有进程中运行的相同多进程pjit()
函数必须按相同顺序运行,但可以与在单个进程中运行的任意操作交替进行。
参数:
- fun(Callable) - 要编译的函数。应为纯函数,因为副作用只能执行一次。其参数和返回值应为数组、标量或其(嵌套的)标准 Python 容器(元组/列表/字典)。由
static_argnums
指示的位置参数可以是任何东西,只要它们是可散列的并且定义了相等操作。静态参数包含在编译缓存键中,这就是为什么必须定义哈希和相等运算符。 - in_shardings –与
fun
参数匹配的 pytree 结构,所有实际参数都替换为资源分配规范。还可以指定一个 pytree 前缀(例如,替换整个子树的一个值),在这种情况下,叶子将广播到该子树的所有值。in_shardings
参数是可选的。JAX 将从输入的jax.Array
推断出分片,并在无法推断出分片时默认复制输入。有效的资源分配规范包括:
Sharding
,它将决定如何分区值。使用网格上下文管理器时,不需要此操作。None
是一种特殊情况,其语义为:
- 如果未提供网格上下文管理器,则 JAX 可以自由选择任何分片方式。对于 in_shardings,JAX 将其标记为复制,但此行为可能在将来更改。对于 out_shardings,我们将依赖于 XLA GSPMD 分区器来确定输出的分片方式。
- 如果提供了网格上下文管理器,则
None
将意味着该值将复制到网格的所有设备上。
- 为了向后兼容,in_shardings 仍支持接受
PartitionSpec
。此选项只能与网格上下文管理器一起使用。
PartitionSpec
,最多与分区值的秩相等长的元组。每个元素可以是None
,一个网格轴或网格轴的元组,并指定分配给分区值维度的资源集,与其在规范中的位置匹配。
- 每个维度的大小必须是其分配的总资源数的倍数。
- out_shardings – 类似于
in_shardings
,但指定了函数输出的资源分配。out_shardings
参数是可选的。如果未指定,jax.jit()
将使用 GSPMD 的分片传播来确定如何分片输出。 - static_argnums(int | Sequence [int] | None) –
可选的整数或整数集合,用于指定将哪些位置参数视为静态(编译时常量)。在 Python 中(在追踪期间),仅依赖于静态参数的操作将被常量折叠,因此相应的参数值可以是任何 Python 对象。
静态参数应该是可哈希的,即实现了__hash__
和__eq__
,并且是不可变的。对于这些常量调用 jitted 函数时,使用不同的值将触发重新编译。不是数组或其容器的参数必须标记为静态。
如果未提供static_argnums
,则不将任何参数视为静态。 - static_argnames (str | Iterable[str] | None) – 可选的字符串或字符串集合,指定要视为静态(编译时常量)的命名参数。有关详细信息,请参阅关于
static_argnums
的注释。如果未提供但设置了static_argnums
,则默认基于调用inspect.signature(fun)
查找相应的命名参数。 - donate_argnums (int | Sequence[int] | None) –
指定要“捐赠”给计算的位置参数缓冲区。如果计算结束后不再需要它们,捐赠参数缓冲区是安全的。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如将您的一个输入缓冲区循环利用来存储结果。您不应重新使用捐赠给计算的缓冲区,如果尝试则 JAX 会引发错误。默认情况下,不会捐赠任何参数缓冲区。
如果既未提供donate_argnums
也未提供donate_argnames
,则不会捐赠任何参数。如果未提供donate_argnums
,但提供了donate_argnames
,或者反之,则 JAX 使用inspect.signature(fun)
查找与donate_argnames
相对应的任何位置参数(或反之)。如果同时提供了donate_argnums
和donate_argnames
,则不使用inspect.signature
,并且只有在donate_argnums
或donate_argnames
中列出的实际参数将被捐赠。
有关缓冲区捐赠的更多详情,请参阅FAQ。 - 捐赠参数名 (str | Iterable[str] | None) – 一个可选的字符串或字符串集合,指定哪些命名参数将捐赠给计算。有关详细信息,请参见对
donate_argnums
的注释。如果未提供但设置了donate_argnums
,则默认基于调用inspect.signature(fun)
查找相应的命名参数。 - 保留未使用 (bool) – 如果为 False(默认值),JAX 确定 fun 未使用的参数 可能 会从生成的编译后 XLA 可执行文件中删除。这些参数将不会传输到设备,也不会提供给底层可执行文件。如果为 True,则不会剪枝未使用的参数。
- 设备 (Device | None) – 此参数已弃用。请在将参数传递给 jit 之前将您需要的设备置于其上。可选,jit 函数将在其上运行的设备。 (可通过
jax.devices()
获取可用设备。)默认情况下,继承自 XLA 的 DeviceAssignment 逻辑,并通常使用jax.devices()[0]
。 - 后端 (str | None) – 此参数已弃用。请在将参数传递给 jit 之前将您需要的后端置于其前。可选,表示 XLA 后端的字符串:
'cpu'
、'gpu'
或'tpu'
。 - 内联 (bool)
- 抽象轴 (Any | None)
返回:
fun
的包装版本,专为即时编译而设,并在每次调用点根据可用的网格自动分区。
返回类型:
JitWrapped
例如,卷积运算符可以通过单个 pjit()
应用自动分区到任意一组设备上:
>>> import jax >>> import jax.numpy as jnp >>> import numpy as np >>> from jax.sharding import Mesh, PartitionSpec >>> from jax.experimental.pjit import pjit >>> >>> x = jnp.arange(8, dtype=jnp.float32) >>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'), ... in_shardings=None, out_shardings=PartitionSpec('devices')) >>> with Mesh(np.array(jax.devices()), ('devices',)): ... print(f(x)) [ 0.5 2\. 4\. 6\. 8\. 10\. 12\. 10\. ]
jax.experimental.sparse 模块
jax.experimental.sparse
模块包括对 JAX 中稀疏矩阵操作的实验性支持。它正在积极开发中,API 可能会更改。主要提供的接口是 BCOO
稀疏数组类型和 sparsify()
变换。
批量坐标(BCOO)稀疏矩阵
JAX 中目前主要的高级稀疏对象是 BCOO
,或者 批量坐标 稀疏数组,它提供与 JAX 变换兼容的压缩存储格式,特别是 JIT(例如 jax.jit()
)、批处理(例如 jax.vmap()
)和自动微分(例如 jax.grad()
)。
下面是一个从稠密数组创建稀疏数组的例子:
>>> from jax.experimental import sparse >>> import jax.numpy as jnp >>> import numpy as np
>>> M = jnp.array([[0., 1., 0., 2.], ... [3., 0., 0., 0.], ... [0., 0., 4., 0.]])
>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp BCOO(float32[3, 4], nse=4)
使用 todense()
方法转换回稠密数组:
>>> M_sp.todense() Array([[0., 1., 0., 2.], [3., 0., 0., 0.], [0., 0., 4., 0.]], dtype=float32)
BCOO 格式是标准 COO 格式的一种略微修改版本,密集表示可以在 data
和 indices
属性中看到:
>>> M_sp.data # Explicitly stored data Array([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices # Indices of the stored data Array([[0, 1], [0, 3], [1, 0], [2, 2]], dtype=int32)
BCOO 对象具有类似数组的属性,以及稀疏特定的属性:
>>> M_sp.ndim 2
>>> M_sp.shape (3, 4)
>>> M_sp.dtype dtype('float32')
>>> M_sp.nse # "number of specified elements" 4
BCOO 对象还实现了许多类数组的方法,允许您直接在 jax 程序中使用它们。例如,在这里我们计算转置矩阵向量乘积:
>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y Array([18., 3., 20., 6.], dtype=float32)
>>> M.T @ y # Compare to dense version Array([18., 3., 20., 6.], dtype=float32)
BCOO 对象设计成与 JAX 变换兼容,包括 jax.jit()
、jax.vmap()
、jax.grad()
等。例如:
>>> from jax import grad, jit
>>> def f(y): ... return (M_sp.T @ y).sum() ... >>> jit(grad(f))(y) Array([3., 3., 4.], dtype=float32)
注意,正常情况下,jax.numpy
和 jax.lax
函数不知道如何处理稀疏矩阵,因此尝试计算诸如 jnp.dot(M_sp.T, y)
的东西将导致错误(但请参见下一节)。
稀疏化变换
JAX 稀疏实现的一个主要目标是提供一种无缝从密集到稀疏计算切换的方法,而无需修改密集实现。这个稀疏实验通过 sparsify()
变换实现了这一目标。
考虑这个函数,它从矩阵和向量输入计算更复杂的结果:
>>> def f(M, v): ... return 2 * jnp.dot(jnp.log1p(M.T), v) + 1 ... >>> f(M, y) Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
如果我们直接传递稀疏矩阵到这个函数,将会导致错误,因为 jnp
函数不识别稀疏输入。然而,使用 sparsify()
,我们得到一个接受稀疏矩阵的函数版本:
>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y) Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
sparsify()
支持包括许多最常见的原语,例如:
- 广义(批量)矩阵乘积和爱因斯坦求和(
dot_general_p
) - 保持零的逐元素二元操作(例如
add_p
、mul_p
等) - 保持零的逐元素一元操作(例如
abs_p
、jax.lax.neg_p
等) - 求和约简(
reduce_sum_p
) - 通用索引操作(
slice_p
、lax.dynamic_slice_p
、lax.gather_p
) - 连接和堆叠(
concatenate_p
) - 转置和重塑(
transpose_p
、reshape_p
、squeeze_p
、broadcast_in_dim_p
) - 一些高阶函数(
cond_p
、while_p
、scan_p
) - 一些简单的 1D 卷积(
conv_general_dilated_p
)
几乎任何 jax.numpy
函数在 sparsify
转换中都可以使用,以操作稀疏数组。这组基元足以支持相对复杂的稀疏工作流程,如下一节所示。
示例:稀疏逻辑回归
作为更复杂稀疏工作流的示例,让我们考虑在 JAX 中实现的简单逻辑回归。请注意,以下实现与稀疏性无关:
>>> import functools >>> from sklearn.datasets import make_classification >>> from jax.scipy import optimize
>>> def sigmoid(x): ... return 0.5 * (jnp.tanh(x / 2) + 1) ... >>> def y_model(params, X): ... return sigmoid(jnp.dot(X, params[1:]) + params[0]) ... >>> def loss(params, X, y): ... y_hat = y_model(params, X) ... return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat)) ... >>> def fit_logreg(X, y): ... params = jnp.zeros(X.shape[1] + 1) ... result = optimize.minimize(functools.partial(loss, X=X, y=y), ... x0=params, method='BFGS') ... return result.x
>>> X, y = make_classification(n_classes=2, random_state=1701) >>> params_dense = fit_logreg(X, y) >>> print(params_dense) [-0.7298445 0.29893667 1.0248291 -0.44436368 0.8785025 -0.7724008 -0.62893456 0.2934014 0.82974285 0.16838408 -0.39774987 -0.5071844 0.2028872 0.5227761 -0.3739224 -0.7104083 2.4212713 0.6310087 -0.67060554 0.03139788 -0.05359547]
这会返回密集逻辑回归问题的最佳拟合参数。要在稀疏数据上拟合相同的模型,我们可以应用sparsify()
转换:
>>> Xsp = sparse.BCOO.fromdense(X) # Sparse version of the input >>> fit_logreg_sp = sparse.sparsify(fit_logreg) # Sparse-transformed fit function >>> params_sparse = fit_logreg_sp(Xsp, y) >>> print(params_sparse) [-0.72971725 0.29878938 1.0246326 -0.44430563 0.8784217 -0.77225566 -0.6288222 0.29335397 0.8293481 0.16820715 -0.39764675 -0.5069753 0.202579 0.522672 -0.3740134 -0.7102678 2.4209507 0.6310593 -0.670236 0.03132951 -0.05356663]
JAX 中文文档(十五)(5)https://developer.aliyun.com/article/1559772