JAX 中文文档(七)(1)https://developer.aliyun.com/article/1559695
all_gather
另一个基本操作是沿轴收集数组片段,以便每个函数应用程序在该轴上都有数据的完整副本:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f4(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_gather(x_block, 'i', tiled=True) print('AFTER:\n', y_block) return y_block x = jnp.array([3, 9, 5, 2]) y = f4(x) print('FINAL RESULT:\n', y)
BEFORE: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [3] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [9] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [5] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [2] AFTER: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [3 9 5 2] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [3 9 5 2] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [3 9 5 2] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [3 9 5 2] FINAL RESULT: [3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]
打印显示,每个函数应用程序再次以其自己的 x_block
参数值块的一个片段开始。在 all_gather
后,它们具有一个通过连接 x_block
值计算得到的共同值。
(请注意,我们实际上不能在此处设置 out_specs=P()
。由于与自动微分相关的技术原因,我们认为 all_gather
的输出不保证在不同设备上不变。如果我们希望它保证不变,我们可以使用 jax.lax.all_gather_invariant
,或者在这种情况下,我们可以避免在函数体中执行 all_gather
,而是只使用 out_specs=P('i')
来执行连接。)
当 tiled=False
(默认情况下)时,结果沿新轴堆叠而不是连接:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f5(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_gather(x_block, 'i', tiled=False) print('AFTER:\n', y_block) return y_block y = f5(x) print('FINAL RESULT:\n', y)
BEFORE: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [3] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [9] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [5] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [2] AFTER: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [[3] [9] [5] [2]] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [[3] [9] [5] [2]] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [[3] [9] [5] [2]] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [[3] [9] [5] [2]] FINAL RESULT: [[3] [9] [5] [2] [3] [9] [5] [2] [3] [9] [5] [2] [3] [9] [5] [2]]
我们可以为 all_gather
编写 collective_ref
引用语义函数:
def all_gather_ref(_, x_blocks, *, tiled=False): combine = jnp.concatenate if tiled else jnp.stack return [combine(x_blocks)] * len(x_blocks)
在深度学习中,我们可以在完全分片数据并行性(FSDP)中对参数使用 all_gather
。
psum_scatter
jax.lax.psum_scatter
集合操作有点不那么直观。它类似于 psum
,但每个函数实例只获得结果的一个分片:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f6(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True) print('AFTER:\n', y_block) return y_block x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2]) y = f6(x) print('FINAL RESULT:\n', y)
BEFORE: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [3 1 4 1] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [5 9 2 6] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [5 3 5 8] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [9 7 1 2] AFTER: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [22] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [20] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [12] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [17] FINAL RESULT: [22 20 12 17]
如打印所示,每个结果的 y_block
比参数 x_block
的大小要小,与 psum
不同。此外,与 psum
相比,这里每个 y_block
只表示函数实例的 x_block
总和的一个片段。 (尽管每个函数实例只获得总和的一个分片,但最终输出 y
与 psum
示例中的相同,因为我们在这里使用 out_specs=P('i')
来连接每个函数实例的输出。)
在计算的值方面,collective_ref
参考实现可能如下所示:
def psum_scatter_ref(i, x_blocks, *, tiled=False): axis_size = len(x_blocks) tot = sum(x_blocks) if tiled: tot = tot.reshape(axis_size, -1, *tot.shape[1:]) # split leading axis return [tot[i] for i in range(tot.shape[0])]
语义参考实现中未捕获,但 psum_scatter
很有用,因为这些结果可以比完整的 psum
更高效地计算,通信量更少。事实上,可以将 psum_scatter
看作是 psum
的“前半部分,即 all_gather
”的一种方式。也就是说,实现 psum
的一种方式是:
def psum(x, axis_name): summed_chunk = jax.lax.psum_scatter(x, axis_name) return jax.lax.all_gather(summed_chunk, axis_name)
实际上,这种实现经常在 TPU 和 GPU 上使用!
psum_scatter
需要约一半通信量的原因在ppermute
部分有所体现。
另一个直觉是,我们可以使用psum_scatter
来实现分布式矩阵乘法,其中输入和输出在相同的轴上分片。在机器学习中,psum_scatter
可以用于张量并行矩阵乘法或完全分片数据并行梯度累积,如下例所示。
ppermute
jax.lax.ppermute
集合提供了实例函数相互发送数据的最直接方式。给定一个网格轴和一个表示沿着该网格轴的索引的(source_index, destination_index)
对列表,ppermute
将其参数值从每个源函数实例发送到每个目标:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f7(x_block): sz = jax.lax.psum(1, 'i') print('BEFORE:\n', x_block) y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)]) print('AFTER:\n', y_block) return y_block y = f7(jnp.arange(8)) print('FINAL RESULT:\n', y)
BEFORE: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [0 1] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [2 3] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [4 5] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [6 7] AFTER: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [6 7] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [0 1] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [2 3] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [4 5] FINAL RESULT: [6 7 0 1 2 3 4 5]
在这种情况下,仅有两个函数实例,每个实例的y_block
值是另一个实例的x_block
值。
源索引和目标索引不能重复。如果一个索引未出现为目标,则相应函数实例结果的值为零数组。
一个collective_ref
的参考实现可能是这样的:
def ppermute_ref(i, x_blocks, perm): results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks) for src, dst in perm: results[dst] = x_blocks[src] return results
其他集合操作可以通过使用ppermute
来实现,其中每个函数只向其邻居传递数据,从而在总通信量方面实现高效。例如,我们可以用这种方式实现psum_scatter
,通过一系列ppermute
和本地加法:
或者,举个数值示例:
直观地说,每次迭代时,每个函数实例都将前一次迭代接收到的值“上送”,并在本次迭代中减少(添加)它接收到的值。在代码中,可能看起来像这样:
def psum_scatter(x, axis_name, *, tiled=False): size = jax.lax.psum(1, axis_name) idx = jax.lax.axis_index(axis_name) # function instance index along axis_name if tiled: x = x.reshape(size, -1, *x.shape[1:]) # split leading axis shift = partial(jax.lax.ppermute, axis_name=axis_name, perm=[(i, (i - 1) % size) for i in range(size)]) for i in range(1, size): update = shift(x[(idx + i) % size]) x = x.at[(idx + i + 1) % size].add(update) return x[idx]
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f8(x_block): print('BEFORE:\n', x_block) y_block = psum_scatter(x_block, 'i', tiled=True) print('AFTER:\n', y_block) return y_block x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2]) y = f8(x) print('FINAL RESULT:\n', y)
BEFORE: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [3 1 4 1] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [5 9 2 6] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [5 3 5 8] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [9 7 1 2] AFTER: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [22] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [20] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [12] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [17] FINAL RESULT: [22 20 12 17]
在 TPU 上,有更高维度的算法变体来利用多向双向物理网格轴。
注意,psum_scatter
是all_gather
的转置。事实上,实现all_gather
的一种方式是使用ppermute
的逆过程:
在深度学习中,当实现 SPMD 管道并行时,我们可能会使用ppermute
,其中我们沿着深度将网络分割成阶段并并行评估阶段的应用。或者,当并行化卷积层的评估时,我们可能会使用ppermute
,其中我们在空间轴上分片,因此设备必须相互通信“halos”。或者在张量并行矩阵乘法的幕后使用它。
all_to_all
最后一个集合操作是all_to_all
,它本质上是沿一个位置轴和一个跨设备轴进行的块矩阵转置操作:
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f9(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0, tiled=True) print('AFTER:\n', y_block) return y_block x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2]) y = f9(x) print('FINAL RESULT:\n', y)
BEFORE: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [3 1 4 1] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [5 9 2 6] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [5 3 5 8] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [9 7 1 2] AFTER: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [3 5 5 9] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [1 9 3 7] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [4 2 5 1] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [1 6 8 2] FINAL RESULT: [3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]
split_axis
参数指示应该在网格轴上分片和分区的位置轴。concat_axis
参数指示应该在通信结果应该被连接或堆叠的轴。
当 tiled=False
(默认情况下),split_axis
轴的大小必须等于命名为 axis_name
的网格轴的大小,并且在位置 concat_axis
创建一个新的该大小的轴用于堆叠结果。当 tiled=True
时,split_axis
轴的大小只需可以被网格轴的大小整除,结果沿现有轴 concat_axis
连接。
当 split_axis=0
和 concat_axis=0
时,collective_ref
引用语义可能如下:
def all_to_all_ref(_, x_blocks, *, tiled=False): axis_size = len(x_blocks) if tiled: splits = [jnp.array_split(x, axis_size) for x in x_blocks] return [jnp.concatenate(s) for s in zip(*splits)] else: splits = [list(x) for x in x_blocks] return [jnp.stack(s) for s in zip(*splits)]
在深度学习中,我们可能在专家混合路由中使用 all_to_all
,我们首先根据它们应该去的专家对我们的本地批次的示例进行排序,然后应用 all_to_all
重新分发示例到专家。
JAX 中文文档(七)(3)https://developer.aliyun.com/article/1559699