JAX 中文文档(十三)(2)

简介: JAX 中文文档(十三)

JAX 中文文档(十三)(1)https://developer.aliyun.com/article/1559741


切换到 jax.Array 后对于主机本地输入的 pjit 有破坏性变更。

如果您完全使用 GDA 参数作为 pjit 的输入,则可以跳过此部分! 🎉

启用jax.Array后,所有传递给pjit的输入必须是全局形状的。这是与之前行为不兼容的变化,之前的pjit会将进程本地的参数连接成一个全局值;现在不再进行此连接。

为什么我们要进行这个突破性的变化?现在每个数组都明确说明了它的本地分片如何适合全局整体,而不是留下隐含的情况。更明确的表示方式还可以解锁额外的灵活性,例如在某些 TPU 模型上可以提高效率的非连续网格使用pjit

在启用jax.Array时,运行多进程 pjit 计算并在传递主机本地输入时可能会导致类似以下错误:

示例:

Mesh = {'x': 2, 'y': 2, 'z': 2} 和主机本地输入形状 == (4,) 以及pspec = P(('x', 'y', 'z'))

因为pjit不会将主机本地形状提升为全局形状,所以您会收到以下错误:

注意:只有当您的主机本地形状小于网格的形状时,才会看到此错误。

ValueError: One of pjit arguments was given the sharding of
NamedSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4 

错误出现是因为当维度0上的值为4时,无法将其分片成 8 份。

如果你仍然将主机本地输入传递给pjit,如何迁移?我们提供了过渡 API 来帮助您迁移:

注意:如果您在单进程上运行pjit计算,则不需要这些实用程序。

from jax.experimental import multihost_utils
global_inps = multihost_utils.host_local_array_to_global_array(
    local_inputs, mesh, in_pspecs)
global_outputs = pjit(f, in_shardings=in_pspecs,
                      out_shardings=out_pspecs)(global_inps)
local_outs = multihost_utils.global_array_to_host_local_array(
    global_outputs, mesh, out_pspecs) 

host_local_array_to_global_array是一种类型转换,它查看具有仅本地分片的值,并将其本地形状更改为在更改之前如果传递该值pjit会假定的形状。

支持完全复制的输入,即每个进程上具有相同形状,并且in_axis_resourcesP(None)的情况。在这种情况下,您无需使用host_local_array_to_global_array,因为形状已经是全局的。

key = jax.random.PRNGKey(1)
# As you can see, using host_local_array_to_global_array is not required since in_axis_resources says
# that the input is fully replicated via P(None)
pjit(f, in_shardings=None, out_shardings=None)(key)
# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
    local_inp, mesh, P('data'))
global_out = pjit(f, in_shardings=(P(None), P('data')),
                  out_shardings=...)(key, global_inp) 

FROM_GDAjax.Array

如果你在in_axis_resources参数中使用FROM_GDA来传递给pjit,那么在使用jax.Array时,无需向in_axis_resources传递任何内容,因为jax.Array将遵循计算遵循分片的语义。

例如:

pjit(f, in_shardings=FROM_GDA, out_shardings=...) can be replaced by pjit(f, out_shardings=...) 

如果你的输入中混合了PartitionSpecsFROM_GDA,例如 numpy 数组等,则使用host_local_array_to_global_array将它们转换为jax.Array

例如:

如果你有这样的情况:

pjitted_f = pjit(
    f, in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
    out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2) 

然后您可以将其替换为:

pjitted_f = pjit(f, out_shardings=...)
array2, array3 = multihost_utils.host_local_array_to_global_array(
    (np_array1, np_array2), mesh, (P('x'), P(None)))
pjitted_f(array1, array2, array3, array4) 

live_buffers替换为live_arrays

jax Device上的live_buffers属性已被弃用。请改用与jax.Array兼容的jax.live_arrays()

处理向pjit传递的主机本地输入,例如批次等。

如果在多进程环境中向pjit传递主机本地输入,请使用multihost_utils.host_local_array_to_global_array将批次转换为全局jax.Array,然后将其传递给pjit

这种主机本地输入最常见的例子是输入数据批次

这对任何主机本地输入都有效(不仅仅是输入数据批次)。

from jax.experimental import multihost_utils
batch = multihost_utils.host_local_array_to_global_array(
    batch, mesh, batch_partition_spec) 

关于这种变化以及更多示例,请参阅上面的 pjit 部分。

RecursionError:递归调用 jit 时发生的错误。

当你的代码的某部分禁用了 jax.Array,然后你仅在其他部分启用它时会出现这种情况。例如,如果你使用某些第三方代码,该代码已禁用了 jax.Array 并从该库获得一个 DeviceArray,然后在你的库中启用 jax.Array 并将该 DeviceArray 传递给 JAX 函数,就会导致 RecursionError。

jax.Array 默认启用时,所有库都返回 jax.Array,除非显式禁用它,这个错误就应该消失。

异步调度

原文:jax.readthedocs.io/en/latest/async_dispatch.html

JAX 使用异步调度来隐藏 Python 的开销。考虑以下程序:

>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.key(0), (1000, 1000))
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
>>> # will block until the value is ready.
>>> jnp.dot(x, x) + 3.  
Array([[258.01971436, 249.64862061, 257.13372803, ...,
 236.67948914, 250.68939209, 241.36853027],
 [265.65979004, 256.28912354, 262.18252563, ...,
 242.03181458, 256.16757202, 252.44122314],
 [262.38916016, 255.72747803, 261.23059082, ...,
 240.83563232, 255.41094971, 249.62471008],
 ...,
 [259.15814209, 253.09197998, 257.72174072, ...,
 242.23876953, 250.72680664, 247.16642761],
 [271.22662354, 261.91204834, 265.33398438, ...,
 248.26651001, 262.05389404, 261.33700562],
 [257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
 248.62597656, 243.22348022]], dtype=float32) 

当执行诸如 jnp.dot(x, x) 这样的操作时,JAX 不会等待操作完成再将控制返回给 Python 程序。相反,JAX 返回一个 jax.Array 值,它是一个未来的值,即将来在加速设备上生成但不一定立即可用的值。我们可以检查 jax.Array 的形状或类型,而无需等待生成它的计算完成,并且甚至可以将其传递给另一个 JAX 计算,正如我们在此处执行加法操作一样。只有当我们实际从主机检查数组的值时,例如通过打印它或将其转换为普通的 numpy.ndarray,JAX 才会强制 Python 代码等待计算完成。

异步调度非常有用,因为它允许 Python 代码在加速设备之前“超前运行”,从而避免 Python 代码进入关键路径。只要 Python  代码将工作快速地加入设备的队列,比它执行得更快,并且只要 Python 代码实际上不需要检查主机上的计算输出,那么 Python  程序就可以加入任意量的工作并避免让加速器等待。

异步调度对微基准测试有一个稍显意外的影响。

>>> %time jnp.dot(x, x)  
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
Array([[255.01972961, 246.64862061, 254.13371277, ...,
 233.67948914, 247.68939209, 238.36853027],
 [262.65979004, 253.28910828, 259.18252563, ...,
 239.03181458, 253.16757202, 249.44122314],
 [259.38916016, 252.72747803, 258.23059082, ...,
 237.83563232, 252.41094971, 246.62471008],
 ...,
 [256.15814209, 250.09197998, 254.72172546, ...,
 239.23876953, 247.72680664, 244.16642761],
 [268.22662354, 258.91204834, 262.33398438, ...,
 245.26651001, 259.05389404, 258.33700562],
 [254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
 245.62597656, 240.22348022]], dtype=float32) 

对于在 CPU 上进行的 1000x1000 矩阵乘法来说,269µs  的时间是一个令人惊讶地小的时间!然而,事实证明异步调度在误导我们,我们并没有计时矩阵乘法的执行,而是调度工作的时间。要测量操作的真正成本,我们必须要么在主机上读取值(例如,将其转换为普通的主机端  numpy 数组),要么在 jax.Array 值上使用 block_until_ready() 方法,等待生成它的计算完成。

>>> %time np.asarray(jnp.dot(x, x))  
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
Out[16]:
array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
 238.36853],
 [262.6598 , 253.28911, 259.18253, ..., 239.03181, 253.16757,
 249.44122],
 [259.38916, 252.72748, 258.2306 , ..., 237.83563, 252.41095,
 246.62471],
 ...,
 [256.15814, 250.09198, 254.72173, ..., 239.23877, 247.7268 ,
 244.16643],
 [268.22662, 258.91205, 262.33398, ..., 245.26651, 259.0539 ,
 258.337  ],
 [254.16135, 251.75433, 256.083  , ..., 238.59848, 245.62598,
 240.22348]], dtype=float32)
>>> %time jnp.dot(x, x).block_until_ready()  
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
Array([[255.01972961, 246.64862061, 254.13371277, ...,
 233.67948914, 247.68939209, 238.36853027],
 [262.65979004, 253.28910828, 259.18252563, ...,
 239.03181458, 253.16757202, 249.44122314],
 [259.38916016, 252.72747803, 258.23059082, ...,
 237.83563232, 252.41094971, 246.62471008],
 ...,
 [256.15814209, 250.09197998, 254.72172546, ...,
 239.23876953, 247.72680664, 244.16642761],
 [268.22662354, 258.91204834, 262.33398438, ...,
 245.26651001, 259.05389404, 258.33700562],
 [254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
 245.62597656, 240.22348022]], dtype=float32) 

在不将结果转移到 Python 的情况下进行阻塞通常更快,通常是编写计算时间微基准测试时的最佳选择。

并发

JAX 并发

JAX 对 Python 并发的支持有限。

客户端可以从不同的 Python 线程并发调用 JAX API(例如,jit()grad())。

不允许同时从多个线程并发地操作 JAX 追踪值。换句话说,虽然可以从多个线程调用使用 JAX 追踪的函数(例如 jit()),但不得使用线程来操作传递给 jit() 的函数 f 实现内部的 JAX 值。如果这样做,最有可能的结果是 JAX 报告一个神秘的错误。


JAX 中文文档(十三)(3)https://developer.aliyun.com/article/1559743

相关文章
|
4月前
|
机器学习/深度学习 编译器 API
JAX 中文文档(十三)(4)
JAX 中文文档(十三)
56 2
|
4月前
|
算法 Serverless 索引
JAX 中文文档(十三)(5)
JAX 中文文档(十三)
29 1
|
4月前
|
缓存 TensorFlow 算法框架/工具
JAX 中文文档(十三)(3)
JAX 中文文档(十三)
74 1
|
4月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
53 3
|
4月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
40 3
|
4月前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
17 1
|
4月前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
33 1
|
4月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
55 1
|
4月前
|
机器学习/深度学习 Shell API
JAX 中文文档(十三)(1)
JAX 中文文档(十三)
41 0
|
4月前
|
机器学习/深度学习 分布式计算 程序员
JAX 中文文档(十一)(1)
JAX 中文文档(十一)
35 0