JAX 中文文档(七)(4)

简介: JAX 中文文档(七)

JAX 中文文档(七)(3)https://developer.aliyun.com/article/1559699


分布式数据加载在多主机/多进程环境中

原文:jax.readthedocs.io/en/latest/distributed_data_loading.html

这个高级指南演示了如何执行分布式数据加载——当你在多主机或多进程环境中运行 JAX 时,用于 JAX 计算的数据被分布在多个进程中。本文档涵盖了分布式数据加载的整体方法,以及如何将其应用于数据并行(更简单)和模型并行(更复杂)的工作负载。

分布式数据加载通常比起其它方法更高效(数据分割在各个进程之间),但同时也更复杂。例如:1)在单一进程中加载整个全局数据,将其分割并通过  RPC 发送到其它进程需要的部分;和  2)在所有进程中加载整个全局数据,然后在每个进程中只使用需要的部分。加载整个全局数据通常更简单但更昂贵。例如,在机器学习中,训练循环可能会因等待数据而阻塞,并且每个进程会使用额外的网络带宽。

注意

当使用分布式数据加载时,每个设备(例如每个 GPU 或  TPU)必须访问其需要运行计算的输入数据分片。这通常使得分布式数据加载比前述的替代方案更复杂和具有挑战性。如果错误的数据分片最终出现在错误的设备上,计算仍然可以正常运行,因为计算无法知道输入数据“应该”是什么。然而,最终结果通常是不正确的,因为输入数据与预期不同。

加载jax.Array的一般方法

考虑一个情况,从未由 JAX 生成的原始数据创建单个jax.Array。这些概念适用于不仅限于加载批量数据记录,例如任何未直接由 JAX 计算产生的多进程jax.Array。例如:1)从检查点加载模型权重;或者 2)加载大型空间分片图像。

每个jax.Array都有一个相关的Sharding,描述了每个全局设备所需的全局数据的哪个分片。当你从头创建一个jax.Array时,你还需要创建其Sharding。这是 JAX 理解数据在各个设备上布局的方式。你可以创建任何你想要的Sharding。在实践中,通常根据你正在实现的并行策略选择一个Sharding(稍后在本指南中将更详细地了解数据和模型并行)。你也可以根据原始数据在每个进程中如何生成来选择一个Sharding

一旦定义了Sharding,你可以使用addressable_devices()为当前进程需要加载数据的设备提供一个设备列表。(注:术语“可寻址设备”是“本地设备”的更一般版本。目标是确保每个进程的数据加载器为其所有本地设备提供正确的数据。)

示例

例如,考虑一个(64, 128)jax.Array,你需要将其分片到 4 个进程,每个进程有 2 个设备(总共 8 个设备)。这将导致 8 个唯一的数据分片,每个设备一个。有许多分片jax.Array的方法。你可以沿着jax.Array的第二维进行 1D 分片,每个设备得到一个(64, 16)的分片,如下所示:


在上图中,每个数据分片都有自己的颜色,表示哪个进程需要加载该分片。例如,假设进程0的 2 个设备包含分片AB,对应于全局数据的第一个(64, 32)部分。

你可以选择不同的分片到设备的分布方式。例如:


这里是另一个示例——二维分片:


但是,无论jax.Array如何分片,你都必须确保每个进程的数据加载器提供/加载全局数据所需的分片。有几种高级方法可以实现这一点:1)在每个进程中加载全局数据;2)使用每设备数据流水线;3)使用合并的每进程数据流水线;4)以某种方便的方式加载数据,然后在计算中重新分片。

选项 1:在每个进程中加载全局数据


使用此选项,每个进程:

  1. 加载所需的完整值;并且
  2. 仅将所需的分片传输到该进程的本地设备。

这并不是一个高效的分布式数据加载方法,因为每个进程都会丢弃其本地设备不需要的数据,并且总体加载的数据量可能会比必要的要多。但这个选项可以运行,并且相对简单实现,对于某些工作负载的性能开销可能是可以接受的(例如,如果全局数据量较小)。

选项 2:使用每设备数据流水线


在此选项中,每个进程为其每个本地设备设置一个数据加载器(即,每个设备仅为其所需的数据分片设置自己的数据加载器)。

这在加载数据方面非常高效。有时,独立考虑每个设备可能比一次性考虑所有进程的本地设备更简单(参见下面的选项 3:使用合并的每进程数据流水线)。然而,多个并发数据加载器有时会导致性能问题。

选项 3:使用集中的每个进程数据管道


如果选择此选项,每个过程:

  1. 设置一个单一的数据加载器,加载所有本地设备所需的数据;然后
  2. 在传输到每个本地设备之前对本地数据进行分片。

这是最有效的分布式加载方式。然而,这也是最复杂的,因为需要逻辑来确定每个设备所需的数据,以及创建一个单一的数据加载,仅加载所有这些数据(理想情况下,没有其他额外的数据)。

选项 4:以某种便捷方式加载数据,在计算中重新分片


这个选项比前述选项(从 1 到 3)更难解释,但通常比它们更容易实现。

想象一个场景,设置数据加载器以精确加载您需要的数据,无论是为每个设备还是每个进程加载器,这可能很困难或几乎不可能。然而,仍然可以为每个进程设置一个数据加载器,加载数据的1 / num_processes,只是没有正确的分片。

然后,继续使用您之前的 2D 示例分片,假设每个过程更容易加载数据的单个列:

然后,您可以创建一个带有表示每列数据的Shardingjax.Array,直接将其传递到计算中,并使用jax.lax.with_sharding_constraint()立即将列分片输入重新分片为所需的分片。由于数据在计算中重新分片,它将通过加速器通信链路(例如 TPU ICI 或 NVLink)进行重新分片。

选项 4 与选项 3(使用集中的每个进程数据管道)具有类似的优点:

  • 每个过程仍然具有单个数据加载器;和
  • 全局数据在所有过程中仅加载一次;和
  • 全局数据的额外好处在于提供如何加载数据的更大灵活性。

然而,这种方法使用加速器互连带宽执行重新分片,可能会降低某些工作负载的速度。选项 4 还要求将输入数据表示为单独的Sharding,除了目标Sharding

复制

复制描述了多个设备具有相同数据分片的过程。上述提到的一般选项(选项 1 到 4)仍然适用于复制。唯一的区别是某些过程可能会加载相同的数据分片。本节描述了完全复制和部分复制。

全部复制

完全复制是所有设备都具有数据的完整副本的过程(即,“分片”是整个数组值)。

在下面的示例中,由于总共有 8 个设备(每个进程 2 个),您将得到完整数据的 8 个副本。数据的每个副本都未分片,即副本存在于单个设备上:


部分复制

部分复制描述了一个过程,其中数据有多个副本,并且每个副本分片到多个设备上。对于给定的数组值,通常有许多执行部分复制的可能方法(注意:对于给定的数组形状,总是存在单一完全复制的Sharding)。

下面是两个可能的示例。

在下面的第一个示例中,每个副本都分片到进程的两个本地设备上,总共有 4 个副本。这意味着每个进程都需要加载完整的全局数据,因为其本地设备将具有数据的完整副本。


在下面的第二个示例中,每个副本仍然分片到两个设备上,但每个设备对是分布在两个不同的进程中。进程 0(粉色)和进程 1(黄色)都只需要加载数据的第一行,而进程 2(绿色)和进程 3(蓝色)都只需要加载数据的第二行:


现在您已经了解了创建 jax.Array 的高级选项,让我们将它们应用于机器学习应用程序的数据加载。

数据并行性

纯数据并行性(无模型并行性)中:

  • 您在每个设备上复制模型;和
  • 每个模型副本(即每个设备)接收不同的副本批次数据。


将输入数据表示为单个 jax.Array 时,该数组包含此步骤所有副本的数据(称为全局批处理),其中 jax.Array 的每个分片包含单个副本批处理。您可以将其表示为跨所有设备的 1D 分片(请查看下面的示例)——换句话说,全局批处理由所有副本批处理沿批处理轴连接在一起组成。


应用此框架,您可以得出结论,进程 0 应该获取全局批处理的第一个季度(8 的 2 分之一),而进程 1 应该获取第二个季度,依此类推。

但是,您如何知道第一个季度是什么?您如何确保进程 0 获得第一个季度?幸运的是,数据并行性有一个非常重要的技巧,这意味着您不必回答这些问题,并使整个设置更简单。

关于数据并行性的重要技巧

诀窍在于您不需要关心哪个每副本批次会落到哪个副本上。因此,不管哪个进程加载了一个批次都无所谓。原因在于每个设备都对应执行相同操作的模型副本,每个设备获取全局批次中的每个每副本批次都无关紧要。

这意味着您可以自由重新排列全局批次中的每副本批次。换句话说,您可以随机化每个设备获取哪个数据分片。

例如:


通常,重新排列jax.Array的数据分片并不是一个好主意 —— 事实上,您是在对jax.Array的值进行置换!然而,对于数据并行处理来说,全局批次顺序并不重要,您可以自由重新排列全局批次中的每个每副本批次,正如前面已经提到的那样。

这简化了数据加载,因为这意味着每个设备只需要独立的每副本批次流,大多数数据加载器可以通过为每个进程创建一个独立的流水线并将结果分割为每副本批次来轻松实现。


这是选项 2: 合并每进程数据流水线的一个实例。您也可以使用其他选项(如 0、1 和 3,在本文档的早期部分有介绍),但这个选项相对简单和高效。

这是一个如何使用 tf.data 实现此设置的示例:

import jax
import tensorflow as tf
import numpy as np
################################################################################
# Step 1: setup the Dataset for pure data parallelism (do once)
################################################################################
# Fake example data (replace with your Dataset)
ds = tf.data.Dataset.from_tensor_slices(
    [np.ones((16, 3)) * i for i in range(100)])
ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
################################################################################
# Step 2: create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step). This can be used with batches
# produced by different data loaders as well!
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()
per_process_batch_size = per_process_batch.shape[0]  # adjust if your batch dim
                                                     # isn't 0
per_replica_batch_size = per_process_batch_size // jax.local_device_count()
assert per_process_batch_size % per_replica_batch_size == 0, \
  "This example doesn't implement padding."
per_replica_batches = np.split(per_process_batch, jax.local_device_count())
# Thanks to the very important trick about data parallelism, no need to care what
# order the devices appear in the sharding.
sharding = jax.sharding.PositionalSharding(jax.devices())
# PositionalSharding must have same rank as data being sharded.
sharding = sharding.reshape((jax.device_count(),) +
                            (1,) * (per_process_batch.ndim - 1))
global_batch_size = per_replica_batch_size * jax.device_count()
global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:])
global_batch_array = jax.make_array_from_single_device_arrays(
    global_batch_shape, sharding,
    # Thanks again to the very important trick, no need to care which device gets
    # which per-replica batch.
    arrays=[jax.device_put(batch, device)
            for batch, device 
            in zip(per_replica_batches, sharding.addressable_devices)])
assert global_batch_array.shape == global_batch_shape
assert (global_batch_array.addressable_shards[0].data.shape ==
        per_replica_batches[0].shape) 

数据 + 模型并行处理

模型并行处理中,您将每个模型副本分片到多个设备上。如果您使用纯模型并行处理(不使用数据并行处理):

  • 只有一个模型副本分片在所有设备上;并且
  • 数据通常在所有设备上完全复制。

本指南考虑了同时使用数据和模型并行处理的情况:

  • 您将多个模型副本中的每一个分片到多个设备上;并且
  • 您可以部分复制数据到每个模型副本 —— 每个模型副本中的设备得到相同的每副本批次,不同模型副本之间的设备得到不同的每副本批次。

进程内的模型并行处理

对于数据加载,最简单的方法可以是在单个进程的本地设备中将每个模型副本分片。

举个例子,让我们切换到每个有 4 个设备的 2 个进程(而不是每个有 2 个设备的 4 个进程)。考虑一个情况,每个模型副本都分片在单个进程的 2 个本地设备上。这导致每个进程有 2 个模型副本,总共 4 个模型副本,如下所示:


在这里,再次强调,输入数据表示为单个jax.Array,其中每个分片是一个每副本批次的 1D 分片,有一个例外:

  • 不同于纯数据并行情况,你引入了部分复制,并制作了 1D 分片全局批次的 2 个副本。
  • 这是因为每个模型副本由两个设备组成,每个设备都需要一个副本批次的拷贝。


将每个模型副本保持在单个进程内可以使事情变得更简单,因为你可以重用上述纯数据并行设置,除非你还需要复制每个副本的批次:


注意

同样重要的是要将每个副本批次复制到正确的设备上! 虽然数据并行性的一个非常重要的技巧意味着你不在乎哪个批次最终落到哪个副本上,但你确实关心单个副本只得到一个批次

例如,这是可以的:


但是,如果你在加载每批数据到本地设备时不小心,可能会意外地创建未复制的数据,即使分片(和并行策略)表明数据已经复制:


如果你意外地创建了应该在单个进程内复制的未复制数据的jax.Array,JAX 将会报错(不过对于跨进程的模型并行性,情况并非总是如此;请参阅下一节)。

下面是使用tf.data实现每个进程模型并行性和数据并行性的示例:

import jax
import tensorflow as tf
import numpy as np
################################################################################
# Step 1: Set up the Dataset with a different data shard per-process (do once)
#         (same as for pure data parallelism)
################################################################################
# Fake example data (replace with your Dataset)
per_process_batches = [np.ones((16, 3)) * i for i in range(100)]
ds = tf.data.Dataset.from_tensor_slices(per_process_batches)
ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
################################################################################
# Step 2: Create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step)
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()
num_model_replicas_per_process = 2 # set according to your parallelism strategy
num_model_replicas_total = num_model_replicas_per_process * jax.process_count()
per_process_batch_size = per_process_batch.shape[0]  # adjust if your batch dim
                                                     # isn't 0
per_replica_batch_size = (per_process_batch_size //
                          num_model_replicas_per_process)
assert per_process_batch_size % per_replica_batch_size == 0, \
  "This example doesn't implement padding."
per_replica_batches = np.split(per_process_batch,
                               num_model_replicas_per_process)
# Create an example `Mesh` for per-process data parallelism. Make sure all devices
# are grouped by process, and then resize so each row is a model replica.
mesh_devices = np.array([jax.local_devices(process_idx)
                         for process_idx in range(jax.process_count())])
mesh_devices = mesh_devices.reshape(num_model_replicas_total, -1)
# Double check that each replica's devices are on a single process.
for replica_devices in mesh_devices:
  num_processes = len(set(d.process_index for d in replica_devices))
  assert num_processes == 1
mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"])
# Shard the data across model replicas. You don't shard across the
# data_parallelism mesh axis, meaning each per-replica shard will be replicated
# across that axis.
sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("model_replicas"))
global_batch_size = per_replica_batch_size * num_model_replicas_total
global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:])
# Create the final jax.Array using jax.make_array_from_callback. The callback
# will be called for each local device, and passed the N-D numpy-style index
# that describes what shard of the global data that device should receive.
#
# You don't need care exactly which index is passed in due to the very important data
# parallelism, but you do use the index argument to make sure you replicate each
# per-replica batch correctly -- the `index` argument will be the same for
# devices in the same model replica, and different for devices in different
# model replicas.
index_to_batch  = {}
def callback(index: tuple[slice, ...]) -> np.ndarray:
  # Python `slice` objects aren't hashable, so manually create dict key.
  index_key = tuple((slice_.start, slice_.stop) for slice_ in index)
  if index_key not in index_to_batch:
    # You don't care which per-replica batch goes to which replica, just take the
    # next unused one.
    index_to_batch[index_key] = per_replica_batches[len(index_to_batch)]
  return index_to_batch[index_key]
global_batch_array = jax.make_array_from_callback(
    global_batch_shape, sharding, callback)
assert global_batch_array.shape == global_batch_shape
assert (global_batch_array.addressable_shards[0].data.shape ==
        per_replica_batches[0].shape) 

跨进程的模型并行性

当模型副本分布在不同进程中时,可能会变得更加有趣,无论是:

  • 因为单个副本无法适应一个进程;或者
  • 因为设备分配并不是按照这种方式设置的。

例如,回到之前的设置,4 个每个有 2 个设备的进程,如果你像这样为副本分配设备:


这与之前的每个进程模型并行性示例相同的并行策略 - 4 个模型副本,每个副本分布在 2 个设备上。唯一的区别在于设备分配 - 每个副本的两个设备分布在不同的进程中,每个进程只负责每个副本批次的一份拷贝(但是对于两个副本)。

像这样跨进程分割模型副本可能看起来是一种随意且不必要的做法(在这个例子中,这可能是这样),但实际的部署可能会采用这种设备分配方式,以最大程度地利用设备之间的通信链路。

数据加载现在变得更加复杂,因为跨进程需要一些额外的协调。在纯数据并行和每个进程模型并行的情况下,每个进程只需加载唯一的数据流即可。现在某些进程必须加载相同的数据,而另一些进程必须加载不同的数据。在上述示例中,进程02(分别显示为粉色和绿色)必须加载相同的 2 个每个副本的批次,并且进程13(分别显示为黄色和蓝色)也必须加载相同的 2 个每个副本的批次(但不同于进程02的批次)。

此外,每个进程不混淆它的 2 个每个副本的批次是非常重要的。虽然您不关心哪个批次落在哪个副本(这是关于数据并行的一个非常重要的技巧),但您需要确保同一个副本中的所有设备获取相同的批次。例如,以下情况是不好的:


注意

截至 2023 年 8 月,JAX 无法检测到如果jax.Array在进程之间的分片应该复制但实际没有复制,则在运行计算时会产生错误结果。因此,请务必注意避免这种情况!

要在每个设备上获取正确的每个副本批次,您需要将全局输入数据表示为以下的jax.Array


JAX 中文文档(七)(5)https://developer.aliyun.com/article/1559701

相关文章
|
4月前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
39 1
|
4月前
|
并行计算 API 异构计算
JAX 中文文档(六)(2)
JAX 中文文档(六)
40 1
|
4月前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
29 1
|
4月前
|
机器学习/深度学习 PyTorch API
JAX 中文文档(六)(1)
JAX 中文文档(六)
36 0
JAX 中文文档(六)(1)
|
4月前
|
存储 Python
JAX 中文文档(十)(3)
JAX 中文文档(十)
29 0
|
4月前
|
数据可视化 TensorFlow 算法框架/工具
JAX 中文文档(三)(2)
JAX 中文文档(三)
63 0
|
4月前
|
存储 PyTorch 测试技术
JAX 中文文档(八)(5)
JAX 中文文档(八)
37 0
|
4月前
|
存储 机器学习/深度学习 并行计算
JAX 中文文档(二)(5)
JAX 中文文档(二)
41 0
|
4月前
|
机器学习/深度学习 API 索引
JAX 中文文档(二)(2)
JAX 中文文档(二)
34 0
|
4月前
|
机器学习/深度学习 缓存 API
JAX 中文文档(一)(4)
JAX 中文文档(一)
53 0