JAX 中文文档(二)(5)

简介: JAX 中文文档(二)

JAX 中文文档(二)(4)https://developer.aliyun.com/article/1559671


分片计算介绍

原文:jax.readthedocs.io/en/latest/sharded-computation.html

本教程介绍了 JAX 中单程序多数据(SPMD)代码的设备并行性。SPMD  是一种并行技术,可以在不同设备上并行运行相同的计算,比如神经网络的前向传播,可以在不同的输入数据上(比如批量中的不同输入)并行运行在不同的设备上,比如几个  GPU 或 Google TPU 上。

本教程涵盖了三种并行计算模式:

  • 通过jax.jit()自动并行化:编译器选择最佳的计算策略(也被称为“编译器接管”)。
  • 使用jax.jit()jax.lax.with_sharding_constraint()半自动并行化
  • 使用jax.experimental.shard_map.shard_map()进行全手动并行化:shard_map可以实现每个设备的代码和显式的通信集合

使用这些 SPMD 的思路,您可以将为一个设备编写的函数转换为可以在多个设备上并行运行的函数。

如果您在 Google Colab 笔记本中运行这些示例,请确保您的硬件加速器是最新的 Google TPU,方法是检查笔记本设置:Runtime > Change runtime type > Hardware accelerator > TPU v2(提供八个可用设备)。

import jax
jax.devices() 
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)] 

关键概念:数据分片

下面列出的所有分布式计算方法的关键是数据分片的概念,描述了如何在可用设备上布置数据。

JAX 如何理解数据在各个设备上的布局?JAX 的数据类型,jax.Array不可变数组数据结构,代表了在一个或多个设备上具有物理存储的数组,并且有助于使并行化成为 JAX 的核心特性。jax.Array对象是专为分布式数据和计算而设计的。每个jax.Array都有一个关联的jax.sharding.Sharding对象,描述了每个全局设备所需的全局数据的分片情况。当您从头开始创建jax.Array时,您还需要创建它的Sharding

在简单的情况下,数组被分片在单个设备上,如下所示:

import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices() 
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)} 
arr.sharding 
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)) 

若要更直观地表示存储布局,jax.debug模块提供了一些辅助工具来可视化数组的分片。例如,jax.debug.visualize_array_sharding()显示了数组如何存储在单个设备的内存中:

jax.debug.visualize_array_sharding(arr) 
TPU 0 

要创建具有非平凡分片的数组,可以为数组定义一个jax.sharding规范,并将其传递给jax.device_put()

在这里,定义一个NamedSharding,它指定了一个带有命名轴的 N 维设备网格,其中jax.sharding.Mesh允许精确的设备放置:

# Pardon the boilerplate; constructing a sharding will become easier in future!
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils
P = jax.sharding.PartitionSpec
devices = mesh_utils.create_device_mesh((2, 4))
mesh = jax.sharding.Mesh(devices, ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding) 
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y')) 

将该Sharding对象传递给jax.device_put(),就可以获得一个分片数组:

arr_sharded = jax.device_put(arr, sharding)
print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded) 
[[ 0\.  1\.  2\.  3\.  4\.  5\.  6\.  7.]
 [ 8\.  9\. 10\. 11\. 12\. 13\. 14\. 15.]
 [16\. 17\. 18\. 19\. 20\. 21\. 22\. 23.]
 [24\. 25\. 26\. 27\. 28\. 29\. 30\. 31.]] 
TPU 0   TPU 1       TPU 2    TPU 3 
 TPU 6   TPU 7       TPU 4    TPU 5 

这里的设备编号并不按数字顺序排列,因为网格反映了设备的环形拓扑结构。

1. 通过jit实现自动并行化

一旦您有了分片数据,最简单的并行计算方法就是将数据简单地传递给jax.jit()编译的函数!在 JAX 中,您只需指定希望代码的输入和输出如何分区,编译器将会自动处理:1)内部所有内容的分区;2)跨设备的通信的编译。

jit背后的 XLA 编译器包含了优化跨多个设备的计算的启发式方法。在最简单的情况下,这些启发式方法可以归结为计算跟随数据

为了演示 JAX 中自动并行化的工作原理,下面是一个使用jax.jit()装饰的延迟执行函数的示例:这是一个简单的逐元素函数,其中每个分片的计算将在与该分片关联的设备上执行,并且输出也以相同的方式进行分片:

@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1
result = f_elementwise(arr_sharded)
print("shardings match:", result.sharding == arr_sharded.sharding) 
shardings match: True 

随着计算变得更加复杂,编译器会决定如何最佳地传播数据的分片。

在这里,您沿着x的主轴求和,并可视化结果值如何存储在多个设备上(使用jax.debug.visualize_array_sharding()):

@jax.jit
def f_contract(x):
  return x.sum(axis=0)
result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result) 
TPU 0,6 TPU 1,7  TPU 2,4 TPU 3,5 
[48\. 52\. 56\. 60\. 64\. 68\. 72\. 76.] 

结果部分复制:即数组的前两个元素复制到设备06,第二个到17,依此类推。

2. 使用约束进行半自动分片

如果您希望在特定计算中对使用的分片进行一些控制,JAX 提供了with_sharding_constraint()函数。您可以使用jax.lax.with_sharding_constraint()(而不是jax.device_put())与jax.jit()一起更精确地控制编译器如何约束中间值和输出的分布。

例如,假设在上面的f_contract中,您希望输出不是部分复制,而是完全在八个设备上进行分片:

@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  # mesh = jax.create_mesh((8,), 'x')
  devices = mesh_utils.create_device_mesh(8)
  mesh = jax.sharding.Mesh(devices, 'x')
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)
result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result) 
TPU 0  TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4  TPU 5 
[48\. 52\. 56\. 60\. 64\. 68\. 72\. 76.] 

这将为您提供具有所需输出分片的函数。

3. 使用shard_map进行手动并行处理

在上述自动并行化方法中,您可以编写一个函数,就像在操作完整数据集一样,jit将会将该计算分配到多个设备上执行。相比之下,使用jax.experimental.shard_map.shard_map(),您需要编写处理单个数据片段的函数,而shard_map将构建完整的函数。

shard_map的工作方式是在设备mesh上映射函数(shard_map在 shards 上进行映射)。在下面的示例中:

  • 与以往一样,jax.sharding.Mesh允许精确的设备放置,使用轴名称参数来表示逻辑和物理轴名称。
  • in_specs参数确定了分片大小。out_specs参数标识了如何将块重新组装在一起。

注意: 如果需要,jax.experimental.shard_map.shard_map()代码可以在jax.jit()内部工作。

from jax.experimental.shard_map import shard_map
P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')
f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))
arr = jnp.arange(32)
f_elementwise_sharded(arr) 
Array([ 1\.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116896,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314599,  1.8403342 ,  2.9812148 ,
        2.3005757 ,  0.42419332, -0.92279506, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244075, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32) 

您编写的函数只“看到”数据的单个批次,可以通过打印设备本地形状来检查:

x = jnp.arange(32)
print(f"global shape: {x.shape=}")
def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2
y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) 
global shape: x.shape=(32,)
device local shape: x.shape=(4,) 

因为每个函数只“看到”数据的设备本地部分,这意味着像聚合的函数需要额外的思考。

例如,这是jax.numpy.sum()shard_map的示例:

def f(x):
  return jnp.sum(x, keepdims=True)
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) 
Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32) 

您的函数f分别在每个分片上运行,并且结果的总和反映了这一点。

如果要跨分片进行求和,您需要显式请求,使用像jax.lax.psum()这样的集合操作:

def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x) 
Array(496, dtype=int32) 

因为输出不再具有分片维度,所以设置out_specs=P()(请记住,out_specs参数标识如何在shard_map中将块重新组装在一起)。

比较这三种方法

在我们记忆中掌握这些概念后,让我们比较简单神经网络层的三种方法。

首先像这样定义您的规范函数:

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias) 
import numpy as np
rng = np.random.default_rng(0)
x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))
layer(x, weights, bias) 
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32) 

您可以使用jax.jit()自动以分布式方式运行此操作,并传递适当分片的数据。

如果您以相同的方式分片xweights的主轴,则矩阵乘法将自动并行发生:

P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')
sharding = jax.sharding.NamedSharding(mesh, P('x'))
x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)
layer(x_sharded, weights_sharded, bias) 
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32) 

或者,您可以在函数中使用jax.lax.with_sharding_constraint()自动分发未分片的输入:

@jax.jit
def layer_auto(x, weights, bias):
  x = jax.lax.with_sharding_constraint(x, sharding)
  weights = jax.lax.with_sharding_constraint(weights, sharding)
  return layer(x, weights, bias)
layer_auto(x, weights, bias)  # pass in unsharded inputs 
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32) 

最后,您可以使用shard_map以相同的方式执行此操作,使用jax.lax.psum()指示矩阵乘积所需的跨分片集合:

from functools import partial
@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)
layer_sharded(x, weights, bias) 
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32) 

下一步

本教程简要介绍了在 JAX 中分片和并行计算的概念。

要深入了解每种 SPMD 方法,请查看以下文档:

  • 分布式数组和自动并行化
  • 使用shard_map进行 SPMD 多设备并行性
相关文章
|
6月前
|
机器学习/深度学习 PyTorch API
JAX 中文文档(六)(1)
JAX 中文文档(六)
54 0
JAX 中文文档(六)(1)
|
6月前
|
安全 编译器 TensorFlow
JAX 中文文档(四)(5)
JAX 中文文档(四)
48 0
|
6月前
|
存储 缓存 索引
JAX 中文文档(五)(3)
JAX 中文文档(五)
86 0
|
6月前
|
机器学习/深度学习 缓存 API
JAX 中文文档(一)(4)
JAX 中文文档(一)
73 0
|
6月前
|
存储 缓存 API
JAX 中文文档(五)(1)
JAX 中文文档(五)
49 0
|
6月前
|
Python
JAX 中文文档(十)(5)
JAX 中文文档(十)
34 0
|
6月前
|
存储 并行计算 数据可视化
JAX 中文文档(六)(3)
JAX 中文文档(六)
45 0
|
6月前
|
并行计算 Linux 异构计算
JAX 中文文档(一)(1)
JAX 中文文档(一)
190 0
|
6月前
|
存储 PyTorch 测试技术
JAX 中文文档(八)(5)
JAX 中文文档(八)
57 0
|
6月前
|
测试技术 API Python
JAX 中文文档(八)(4)
JAX 中文文档(八)
47 0