JAX 中文文档(七)(2)

简介: JAX 中文文档(七)

JAX 中文文档(七)(1)https://developer.aliyun.com/article/1559695


all_gather

另一个基本操作是沿轴收集数组片段,以便每个函数应用程序在该轴上都有数据的完整副本:

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block
x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y) 
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]
FINAL RESULT:
 [3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2] 

打印显示,每个函数应用程序再次以其自己的 x_block 参数值块的一个片段开始。在 all_gather 后,它们具有一个通过连接 x_block 值计算得到的共同值。

(请注意,我们实际上不能在此处设置 out_specs=P()。由于与自动微分相关的技术原因,我们认为 all_gather 的输出不保证在不同设备上不变。如果我们希望它保证不变,我们可以使用 jax.lax.all_gather_invariant,或者在这种情况下,我们可以避免在函数体中执行 all_gather,而是只使用 out_specs=P('i') 来执行连接。)

tiled=False(默认情况下)时,结果沿新轴堆叠而不是连接:

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
  print('AFTER:\n', y_block)
  return y_block
y = f5(x)
print('FINAL RESULT:\n', y) 
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
 [9]
 [5]
 [2]]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
 [9]
 [5]
 [2]]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
 [9]
 [5]
 [2]]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
 [9]
 [5]
 [2]]
FINAL RESULT:
 [[3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]] 

我们可以为 all_gather 编写 collective_ref 引用语义函数:

def all_gather_ref(_, x_blocks, *, tiled=False):
  combine = jnp.concatenate if tiled else jnp.stack
  return [combine(x_blocks)] * len(x_blocks) 

在深度学习中,我们可以在完全分片数据并行性(FSDP)中对参数使用 all_gather

psum_scatter

jax.lax.psum_scatter 集合操作有点不那么直观。它类似于 psum,但每个函数实例只获得结果的一个分片:

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y) 
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
 [22 20 12 17] 

如打印所示,每个结果的 y_block 比参数 x_block 的大小要小,与 psum 不同。此外,与 psum 相比,这里每个 y_block 只表示函数实例的 x_block 总和的一个片段。 (尽管每个函数实例只获得总和的一个分片,但最终输出 ypsum 示例中的相同,因为我们在这里使用 out_specs=P('i') 来连接每个函数实例的输出。)

在计算的值方面,collective_ref 参考实现可能如下所示:

def psum_scatter_ref(i, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  tot = sum(x_blocks)
  if tiled:
    tot = tot.reshape(axis_size, -1, *tot.shape[1:])  # split leading axis
  return [tot[i] for i in range(tot.shape[0])] 

语义参考实现中未捕获,但 psum_scatter 很有用,因为这些结果可以比完整的 psum 更高效地计算,通信量更少。事实上,可以将 psum_scatter 看作是 psum 的“前半部分,即 all_gather”的一种方式。也就是说,实现 psum 的一种方式是:

def psum(x, axis_name):
  summed_chunk = jax.lax.psum_scatter(x, axis_name)
  return jax.lax.all_gather(summed_chunk, axis_name) 

实际上,这种实现经常在 TPU 和 GPU 上使用!

psum_scatter需要约一半通信量的原因在ppermute部分有所体现。

另一个直觉是,我们可以使用psum_scatter来实现分布式矩阵乘法,其中输入和输出在相同的轴上分片。在机器学习中,psum_scatter可以用于张量并行矩阵乘法或完全分片数据并行梯度累积,如下例所示。

ppermute

jax.lax.ppermute集合提供了实例函数相互发送数据的最直接方式。给定一个网格轴和一个表示沿着该网格轴的索引的(source_index, destination_index)对列表,ppermute将其参数值从每个源函数实例发送到每个目标:

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
  sz = jax.lax.psum(1, 'i')
  print('BEFORE:\n', x_block)
  y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
  print('AFTER:\n', y_block)
  return y_block
y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y) 
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]
FINAL RESULT:
 [6 7 0 1 2 3 4 5] 

在这种情况下,仅有两个函数实例,每个实例的y_block值是另一个实例的x_block值。

源索引和目标索引不能重复。如果一个索引未出现为目标,则相应函数实例结果的值为零数组。

一个collective_ref的参考实现可能是这样的:

def ppermute_ref(i, x_blocks, perm):
  results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
  for src, dst in perm:
    results[dst] = x_blocks[src]
  return results 

其他集合操作可以通过使用ppermute来实现,其中每个函数只向其邻居传递数据,从而在总通信量方面实现高效。例如,我们可以用这种方式实现psum_scatter,通过一系列ppermute和本地加法:

或者,举个数值示例:

直观地说,每次迭代时,每个函数实例都将前一次迭代接收到的值“上送”,并在本次迭代中减少(添加)它接收到的值。在代码中,可能看起来像这样:

def psum_scatter(x, axis_name, *, tiled=False):
  size = jax.lax.psum(1, axis_name)
  idx = jax.lax.axis_index(axis_name)  # function instance index along axis_name
  if tiled:
    x = x.reshape(size, -1, *x.shape[1:])  # split leading axis
  shift = partial(jax.lax.ppermute, axis_name=axis_name,
                  perm=[(i, (i - 1) % size) for i in range(size)])
  for i in range(1, size):
    update = shift(x[(idx + i) % size])
    x = x.at[(idx + i + 1) % size].add(update)
  return x[idx] 
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
  print('BEFORE:\n', x_block)
  y_block = psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y) 
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
 [22 20 12 17] 

在 TPU 上,有更高维度的算法变体来利用多向双向物理网格轴。

注意,psum_scatterall_gather的转置。事实上,实现all_gather的一种方式是使用ppermute的逆过程:

在深度学习中,当实现 SPMD 管道并行时,我们可能会使用ppermute,其中我们沿着深度将网络分割成阶段并并行评估阶段的应用。或者,当并行化卷积层的评估时,我们可能会使用ppermute,其中我们在空间轴上分片,因此设备必须相互通信“halos”。或者在张量并行矩阵乘法的幕后使用它。

all_to_all

最后一个集合操作是all_to_all,它本质上是沿一个位置轴和一个跨设备轴进行的块矩阵转置操作:

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
                               tiled=True)
  print('AFTER:\n', y_block)
  return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y) 
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]
FINAL RESULT:
 [3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2] 

split_axis 参数指示应该在网格轴上分片和分区的位置轴。concat_axis 参数指示应该在通信结果应该被连接或堆叠的轴。

tiled=False(默认情况下),split_axis 轴的大小必须等于命名为 axis_name 的网格轴的大小,并且在位置 concat_axis 创建一个新的该大小的轴用于堆叠结果。当 tiled=True 时,split_axis 轴的大小只需可以被网格轴的大小整除,结果沿现有轴 concat_axis 连接。

split_axis=0concat_axis=0 时,collective_ref 引用语义可能如下:

def all_to_all_ref(_, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  if tiled:
    splits = [jnp.array_split(x, axis_size) for x in x_blocks]
    return [jnp.concatenate(s) for s in zip(*splits)]
  else:
    splits = [list(x) for x in x_blocks]
    return [jnp.stack(s) for s in zip(*splits)] 

在深度学习中,我们可能在专家混合路由中使用 all_to_all,我们首先根据它们应该去的专家对我们的本地批次的示例进行排序,然后应用 all_to_all 重新分发示例到专家。


JAX 中文文档(七)(3)https://developer.aliyun.com/article/1559699

相关文章
|
4月前
|
安全 编译器 TensorFlow
JAX 中文文档(四)(5)
JAX 中文文档(四)
33 0
|
4月前
|
机器学习/深度学习 异构计算 AI芯片
JAX 中文文档(七)(4)
JAX 中文文档(七)
31 0
|
4月前
|
存储 PyTorch 测试技术
JAX 中文文档(八)(5)
JAX 中文文档(八)
37 0
|
4月前
|
编译器 异构计算 Python
JAX 中文文档(四)(2)
JAX 中文文档(四)
31 0
|
4月前
|
编译器 API 异构计算
JAX 中文文档(一)(2)
JAX 中文文档(一)
60 0
|
4月前
|
编译器 测试技术 API
JAX 中文文档(四)(4)
JAX 中文文档(四)
39 0
|
4月前
|
C++ 索引 Python
JAX 中文文档(九)(2)
JAX 中文文档(九)
24 0
|
4月前
|
存储 移动开发 Python
JAX 中文文档(八)(2)
JAX 中文文档(八)
28 0
|
4月前
|
机器学习/深度学习 并行计算 安全
JAX 中文文档(七)(1)
JAX 中文文档(七)
48 0
|
4月前
|
并行计算 编译器
JAX 中文文档(六)(4)
JAX 中文文档(六)
26 0