JAX 中文文档(六)(1)https://developer.aliyun.com/article/1559681
贝叶斯推断的自动批处理
原文:
jax.readthedocs.io/en/latest/notebooks/vmapped_log_probs.html
[外链图片转存中…(img-Fd5vUVOI-1718950514656)]
本笔记演示了一个简单的贝叶斯推断示例,其中自动批处理使用户代码更易于编写、更易于阅读,减少了错误的可能性。
灵感来自@davmre 的一个笔记本。
import functools import itertools import re import sys import time from matplotlib.pyplot import * import jax from jax import lax import jax.numpy as jnp import jax.scipy as jsp from jax import random import numpy as np import scipy as sp
生成一个虚拟的二分类数据集
np.random.seed(10009) num_features = 10 num_points = 100 true_beta = np.random.randn(num_features).astype(jnp.float32) all_x = np.random.randn(num_points, num_features).astype(jnp.float32) y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32)
y
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)
编写模型的对数联合函数
我们将编写一个非批处理版本、一个手动批处理版本和一个自动批处理版本。
非批量化
def log_joint(beta): result = 0. # Note that no `axis` parameter is provided to `jnp.sum`. result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.)) result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta)))) return result
log_joint(np.random.randn(num_features))
Array(-213.2356, dtype=float32)
# This doesn't work, because we didn't write `log_prob()` to handle batching. try: batch_size = 10 batched_test_beta = np.random.randn(batch_size, num_features) log_joint(np.random.randn(batch_size, num_features)) except ValueError as e: print("Caught expected exception " + str(e))
Caught expected exception Incompatible shapes for broadcasting: shapes=[(100,), (100, 10)]
手动批处理
def batched_log_joint(beta): result = 0. # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis # or setting it incorrectly yields an error; at worst, it silently changes the # semantics of the model. result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.), axis=-1) # Note the multiple transposes. Getting this right is not rocket science, # but it's also not totally mindless. (I didn't get it right on the first # try.) result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta.T).T)), axis=-1) return result
batch_size = 10 batched_test_beta = np.random.randn(batch_size, num_features) batched_log_joint(batched_test_beta)
Array([-147.84033 , -207.02205 , -109.26075 , -243.80833 , -163.0291 , -143.84848 , -160.28773 , -113.771706, -126.60544 , -190.81992 ], dtype=float32)
使用 vmap 进行自动批处理
它只是有效地工作。
vmap_batched_log_joint = jax.vmap(log_joint) vmap_batched_log_joint(batched_test_beta)
Array([-147.84033 , -207.02205 , -109.26075 , -243.80833 , -163.0291 , -143.84848 , -160.28773 , -113.771706, -126.60544 , -190.81992 ], dtype=float32)
自包含的变分推断示例
从上面复制了一小段代码。
设置(批量化的)对数联合函数
@jax.jit def log_joint(beta): result = 0. # Note that no `axis` parameter is provided to `jnp.sum`. result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.)) result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta)))) return result batched_log_joint = jax.jit(jax.vmap(log_joint))
定义 ELBO 及其梯度
def elbo(beta_loc, beta_log_scale, epsilon): beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi)) elbo = jax.jit(elbo) elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
使用 SGD 优化 ELBO
def normal_sample(key, shape): """Convenience function for quasi-stateful RNG.""" new_key, sub_key = random.split(key) return new_key, random.normal(sub_key, shape) normal_sample = jax.jit(normal_sample, static_argnums=(1,)) key = random.key(10003) beta_loc = jnp.zeros(num_features, jnp.float32) beta_log_scale = jnp.zeros(num_features, jnp.float32) step_size = 0.01 batch_size = 128 epsilon_shape = (batch_size, num_features) for i in range(1000): key, epsilon = normal_sample(key, epsilon_shape) elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad( beta_loc, beta_log_scale, epsilon) beta_loc += step_size * beta_loc_grad beta_log_scale += step_size * beta_log_scale_grad if i % 10 == 0: print('{}\t{}'.format(i, elbo_val))
0 -180.8538818359375 10 -113.06045532226562 20 -102.73727416992188 30 -99.787353515625 40 -98.90898132324219 50 -98.29745483398438 60 -98.18632507324219 70 -97.57972717285156 80 -97.28599548339844 90 -97.46996307373047 100 -97.4771728515625 110 -97.5806655883789 120 -97.4943618774414 130 -97.50271606445312 140 -96.86396026611328 150 -97.44197845458984 160 -97.06941223144531 170 -96.84028625488281 180 -97.21336364746094 190 -97.56503295898438 200 -97.26397705078125 210 -97.11979675292969 220 -97.39595031738281 230 -97.16831970214844 240 -97.118408203125 250 -97.24345397949219 260 -97.29788970947266 270 -96.69286346435547 280 -96.96438598632812 290 -97.30055236816406 300 -96.63591766357422 310 -97.0351791381836 320 -97.52909088134766 330 -97.28811645507812 340 -97.07321166992188 350 -97.15619659423828 360 -97.25881958007812 370 -97.19515228271484 380 -97.13092041015625 390 -97.11726379394531 400 -96.938720703125 410 -97.26676940917969 420 -97.35322570800781 430 -97.21007537841797 440 -97.28434753417969 450 -97.1630859375 460 -97.2612533569336 470 -97.21343994140625 480 -97.23997497558594 490 -97.14913940429688 500 -97.23527526855469 510 -96.93419647216797 520 -97.21209716796875 530 -96.82575988769531 540 -97.01284790039062 550 -96.94175720214844 560 -97.16520690917969 570 -97.29165649414062 580 -97.42941284179688 590 -97.24370574951172 600 -97.15222930908203 610 -97.49844360351562 620 -96.9906997680664 630 -96.88956451416016 640 -96.89968872070312 650 -97.13793182373047 660 -97.43705749511719 670 -96.99235534667969 680 -97.15623474121094 690 -97.1869125366211 700 -97.11160278320312 710 -97.78105163574219 720 -97.23226165771484 730 -97.16206359863281 740 -96.99581909179688 750 -96.6672134399414 760 -97.16795349121094 770 -97.51435089111328 780 -97.28900146484375 790 -96.91226196289062 800 -97.17100524902344 810 -97.29047393798828 820 -97.16242980957031 830 -97.19107055664062 840 -97.56382751464844 850 -97.00194549560547 860 -96.86555480957031 870 -96.76338195800781 880 -96.83660888671875 890 -97.12178039550781 900 -97.09554290771484 910 -97.0682373046875 920 -97.11947631835938 930 -96.87930297851562 940 -97.45624542236328 950 -96.69279479980469 960 -97.29376220703125 970 -97.3353042602539 980 -97.34962463378906 990 -97.09675598144531
显示结果
虽然覆盖率不及理想,但也不错,而且没有人说变分推断是精确的。
figure(figsize=(7, 7)) plot(true_beta, beta_loc, '.', label='Approximated Posterior Means') plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars') plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.') plot_scale = 3 plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k') xlabel('True beta') ylabel('Estimated beta') legend(loc='best')
<matplotlib.legend.Legend at 0x7f6a2c3c86a0>
在多主机和多进程环境中使用 JAX
介绍
本指南解释了如何在 GPU 集群和Cloud TPU pod 等环境中使用 JAX,在这些环境中,加速器分布在多个 CPU 主机或 JAX 进程上。我们将这些称为“多进程”环境。
本指南专门介绍了如何在多进程设置中使用集体通信操作(例如 jax.lax.psum()
),尽管根据您的用例,其他通信方法也可能有用(例如 RPC,mpi4jax)。如果您尚未熟悉 JAX 的集体操作,建议从分片计算部分开始。在 JAX 的多进程环境中,重要的要求是加速器之间的直接通信链路,例如 Cloud TPU 的高速互连或NCCL 用于 GPU。这些链路允许集体操作在多个进程的加速器上高性能运行。
多进程编程模型
关键概念:
- 您必须在每个主机上至少运行一个 JAX 进程。
- 您应该使用
jax.distributed.initialize()
初始化集群。 - 每个进程都有一组独特的本地设备可以访问。全局设备是所有进程的所有设备集合。
- 使用标准的 JAX 并行 API,如
jit()
(参见分片计算入门教程)和shard_map()
。jax.jit 仅接受全局形状的数组。shard_map 允许您按设备形状进行降级。 - 确保所有进程按照相同顺序运行相同的并行计算。
- 确保所有进程具有相同数量的本地设备。
- 确保所有设备相同(例如,全部为 V100 或全部为 H100)。
启动 JAX 进程
与其他分布式系统不同,其中单个控制节点管理多个工作节点,JAX 使用“多控制器”编程模型,其中每个 JAX Python 进程独立运行,有时称为单程序多数据(SPMD)模型。通常,在每个进程中运行相同的 JAX Python 程序,每个进程的执行之间只有轻微差异(例如,不同的进程将加载不同的输入数据)。此外,您必须手动在每个主机上运行您的 JAX 程序! JAX 不会从单个程序调用自动启动多个进程。
(对于多个进程的要求,这就是为什么本指南不作为笔记本提供的原因——我们目前没有好的方法来从单个笔记本管理多个 Python 进程。)
初始化集群
要初始化集群,您应该在每个进程的开始调用 jax.distributed.initialize()
。jax.distributed.initialize()
必须在程序中的任何 JAX 计算执行之前早些时候调用。
API jax.distributed.initialize()
接受几个参数,即:
coordinator_address
:集群中进程 0 的 IP 地址,以及该进程上可用的一个端口。进程 0 将启动一个通过该 IP 地址和端口暴露的 JAX 服务,集群中的其他进程将连接到该服务。coordinator_bind_address
:集群中进程 0 上的 JAX 服务将绑定到的 IP 地址和端口。默认情况下,它将使用与coordinator_address
相同端口的所有可用接口进行绑定。num_processes
:集群中的进程数process_id
:本进程的 ID 号码,范围为[0 .. num_processes)
。local_device_ids
:将当前进程的可见设备限制为local_device_ids
。
例如,在 GPU 上,典型用法如下:
import jax jax.distributed.initialize(coordinator_address="192.168.0.1:1234", num_processes=2, process_id=0)
在 Cloud TPU、Slurm 和 Open MPI 环境中,你可以简单地调用 jax.distributed.initialize()
而无需参数。参数的默认值将自动选择。在使用 Slurm 和 Open MPI 运行 GPU 时,假定每个 GPU 启动一个进程,即每个进程只分配一个可见本地设备。否则假定每个主机启动一个进程,即每个进程将分配所有本地设备。只有当通过 mpirun
/mpiexec
启动 JAX 进程时才会使用 Open MPI 自动初始化。
import jax jax.distributed.initialize()
在当前 TPU 上,调用 jax.distributed.initialize()
目前是可选的,但建议使用,因为它启用了额外的检查点和健康检查功能。
本地与全局设备
在开始从您的程序中运行多进程计算之前,了解本地和全局设备之间的区别是很重要的。
进程的本地设备是它可以直接寻址和启动计算的设备。 例如,在 GPU 集群上,每个主机只能在直接连接的 GPU 上启动计算。在 Cloud TPU pod 上,每个主机只能在直接连接到该主机的 8 个 TPU 核心上启动计算(有关更多详情,请参阅Cloud TPU 系统架构文档)。你可以通过 jax.local_devices()
查看进程的本地设备。
全局设备是跨所有进程的设备。 一个计算可以跨进程的设备并通过设备之间的直接通信链路执行集体操作,只要每个进程在其本地设备上启动计算即可。你可以通过 jax.devices()
查看所有可用的全局设备。一个进程的本地设备总是全局设备的一个子集。
运行多进程计算
那么,你到底如何运行涉及跨进程通信的计算呢? 使用与单进程中相同的并行评估 API!
例如,shard_map()
可以用于在多个进程间并行计算。(如果您还不熟悉如何使用 shard_map
在单个进程内的多个设备上运行,请参阅分片计算介绍教程。)从概念上讲,这可以被视为在跨主机分片的单个数组上运行 pmap,其中每个主机只“看到”其本地分片的输入和输出。
下面是多进程 pmap 的实际示例:
# The following is run in parallel on each host on a GPU cluster or TPU pod slice. >>> import jax >>> jax.distributed.initialize() # On GPU, see above for the necessary arguments. >>> jax.device_count() # total number of accelerator devices in the cluster 32 >>> jax.local_device_count() # number of accelerator devices attached to this host 8 # The psum is performed over all mapped devices across the pod slice >>> xs = jax.numpy.ones(jax.local_device_count()) >>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs) ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)
非常重要的是,所有进程以相同的跨进程计算顺序运行。 在每个进程中运行相同的 JAX Python 程序通常就足够了。尽管运行相同程序,但仍需注意可能导致不同顺序计算的一些常见陷阱:
- 将不同形状的输入传递给同一并行函数的进程可能导致挂起或不正确的返回值。只要它们在进程间产生相同形状的每设备数据分片,不同形状的输入是安全的;例如,传递不同的前导批次大小以在不同的本地设备数上运行是可以的,但是每个进程根据不同的最大示例长度填充其批次是不行的。
- “最后一批”问题发生在并行函数在(训练)循环中调用时,其中一个或多个进程比其余进程更早退出循环。这将导致其余进程挂起,等待已经完成的进程开始计算。
- 基于集合的非确定性顺序的条件可能导致代码进程挂起。例如,在当前 Python 版本上遍历
set
或者 Python 3.7 之前的dict
可能会导致不同进程的顺序不同,即使插入顺序相同也是如此
JAX 中文文档(六)(3)https://developer.aliyun.com/article/1559683