JAX 中文文档(十四)(2)https://developer.aliyun.com/article/1559756
jax.random 模块
伪随机数生成的实用程序。
jax.random
包提供了多种例程,用于确定性生成伪随机数序列。
基本用法
>>> seed = 1701 >>> num_steps = 100 >>> key = jax.random.key(seed) >>> for i in range(num_steps): ... key, subkey = jax.random.split(key) ... params = compiled_update(subkey, params, next(batches))
PRNG keys
与 NumPy 和 SciPy 用户习惯的 有状态 伪随机数生成器(PRNGs)不同,JAX 随机函数都要求作为第一个参数传递一个显式的 PRNG 状态。随机状态由我们称之为 key 的特殊数组元素类型描述,通常由 jax.random.key()
函数生成:
>>> from jax import random >>> key = random.key(0) >>> key Array((), dtype=key<fry>) overlaying: [0 0]
然后,可以在 JAX 的任何随机数生成例程中使用该 key:
>>> random.uniform(key) Array(0.41845703, dtype=float32)
请注意,使用 key 不会修改它,因此重复使用相同的 key 将导致相同的结果:
>>> random.uniform(key) Array(0.41845703, dtype=float32)
如果需要新的随机数,可以使用 jax.random.split()
生成新的子 key:
>>> key, subkey = random.split(key) >>> random.uniform(subkey) Array(0.10536897, dtype=float32)
注意
类型化的 key 数组,例如上述 key
,在 JAX v0.4.16 中引入。在此之前,key 通常以 uint32
数组表示,其最终维度表示 key 的位级表示。
两种形式的 key 数组仍然可以通过 jax.random
模块创建和使用。新式的类型化 key 数组使用 jax.random.key()
创建。传统的 uint32
key 数组使用 jax.random.PRNGKey()
创建。
要在两者之间进行转换,使用 jax.random.key_data()
和 jax.random.wrap_key_data()
。当与 JAX 外部系统(例如将数组导出为可序列化格式)交互或将 key 传递给基于 JAX 的库时,可能需要传统的 key 格式。
否则,建议使用类型化的 key。传统 key 相对于类型化 key 的注意事项包括:
- 它们有一个额外的尾维度。
- 它们具有数字数据类型 (
uint32
),允许进行通常不用于 key 的操作,例如整数算术。 - 它们不包含有关 RNG 实现的信息。当传统 key 传递给
jax.random
函数时,全局配置设置确定 RNG 实现(参见下文的“高级 RNG 配置”)。
要了解更多关于此升级以及 key 类型设计的信息,请参阅 JEP 9263。
高级
设计和背景
TLDR:JAX PRNG = Threefry counter PRNG + 一个功能数组导向的 分裂模型
更多详细信息,请参阅 docs/jep/263-prng.md。
总结一下,JAX PRNG 还包括但不限于以下要求:
- 确保可重现性,
- 良好的并行化,无论是向量化(生成数组值)还是多副本、多核计算。特别是它不应在随机函数调用之间使用顺序约束。
高级 RNG 配置
JAX 提供了几种 PRNG 实现。可以通过可选的 impl 关键字参数选择特定的实现。如果在密钥构造函数中没有传递 impl 选项,则实现由全局 jax_default_prng_impl 配置标志确定。
- 默认,“threefry2x32”: 基于 Threefry 哈希函数构建的基于计数器的 PRNG。
- 实验性 一种仅包装了 XLA 随机位生成器(RBG)算法的 PRNG。请参阅 TF 文档。
- “rbg” 使用 ThreeFry 进行分割,并使用 XLA RBG 进行数据生成。
- “unsafe_rbg” 仅用于演示目的,使用 RBG 进行分割(使用未经测试的虚构算法)和生成。
- 这些实验性实现生成的随机流尚未经过任何经验随机性测试(例如 Big Crush)。生成的随机比特可能会在 JAX 的不同版本之间变化。
不使用默认 RNG 的可能原因是:
- 可能编译速度较慢(特别是对于 Google Cloud TPU)
- 在 TPU 上执行速度较慢
- 不支持高效的自动分片/分区
这里是一个简短的总结:
属性 | Threefry | Threefry* | rbg | unsafe_rbg | rbg** | unsafe_rbg** |
在 TPU 上最快 | ✅ | ✅ | ✅ | ✅ | ||
可以高效分片(使用 pjit) | ✅ | ✅ | ✅ | |||
在分片中相同 | ✅ | ✅ | ✅ | ✅ | ||
在 CPU/GPU/TPU 上相同 | ✅ | ✅ | ||||
在 JAX/XLA 版本间相同 | ✅ | ✅ |
(*): 设置了jax_threefry_partitionable=1
(**): 设置了XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
“rbg” 和 “unsafe_rbg” 之间的区别在于,“rbg” 用于生成随机值时使用了较不稳定/研究较少的哈希函数(但不用于 jax.random.split 或 jax.random.fold_in),而 “unsafe_rbg” 还额外在 jax.random.split 和 jax.random.fold_in 中使用了更不稳定的哈希函数。因此,在不同密钥生成的随机流质量方面不那么安全。
要了解有关 jax_threefry_partitionable 的更多信息,请参阅jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
API 参考
密钥创建与操作
PRNGKey (seed, *[, impl]) |
给定整数种子创建伪随机数生成器(PRNG)密钥。 |
key (seed, *[, impl]) |
给定整数种子创建伪随机数生成器(PRNG)密钥。 |
key_data (密钥) |
恢复 PRNG 密钥数组下的密钥数据位。 |
wrap_key_data (key_bits_array, *[, impl]) |
将密钥数据位数组包装成 PRNG 密钥数组。 |
fold_in (key, data) |
将数据折叠到 PRNG 密钥中,形成新的 PRNG 密钥。 |
split (key[, num]) |
将 PRNG 密钥按添加一个前导轴拆分为 num 个新密钥。 |
clone (key) |
克隆一个密钥以便重复使用。 |
随机抽样器
ball (key, d[, p, shape, dtype]) |
从单位 Lp 球中均匀采样。 |
bernoulli (key[, p, shape]) |
采样给定形状和均值的伯努利分布随机值。 |
beta (key, a, b[, shape, dtype]) |
采样给定形状和浮点数数据类型的贝塔分布随机值。 |
binomial (key, n, p[, shape, dtype]) |
采样给定形状和浮点数数据类型的二项分布随机值。 |
bits (key[, shape, dtype]) |
以无符号整数的形式采样均匀比特。 |
categorical (key, logits[, axis, shape]) |
从分类分布中采样随机值。 |
cauchy (key[, shape, dtype]) |
采样给定形状和浮点数数据类型的柯西分布随机值。 |
chisquare (key, df[, shape, dtype]) |
采样给定形状和浮点数数据类型的卡方分布随机值。 |
choice (key, a[, shape, replace, p, axis]) |
从给定数组中生成随机样本。 |
dirichlet (key, alpha[, shape, dtype]) |
采样给定形状和浮点数数据类型的狄利克雷分布随机值。 |
double_sided_maxwell (key, loc, scale[, …]) |
从双边 Maxwell 分布中采样。 |
exponential (key[, shape, dtype]) |
采样给定形状和浮点数数据类型的指数分布随机值。 |
f (key, dfnum, dfden[, shape, dtype]) |
采样给定形状和浮点数数据类型的 F 分布随机值。 |
gamma (key, a[, shape, dtype]) |
采样给定形状和浮点数数据类型的伽马分布随机值。 |
generalized_normal (key, p[, shape, dtype]) |
从广义正态分布中采样。 |
geometric (key, p[, shape, dtype]) |
采样给定形状和浮点数数据类型的几何分布随机值。 |
gumbel (key[, shape, dtype]) |
采样给定形状和浮点数数据类型的 Gumbel 分布随机值。 |
laplace (key[, shape, dtype]) |
采样给定形状和浮点数数据类型的拉普拉斯分布随机值。 |
loggamma (key, a[, shape, dtype]) |
采样给定形状和浮点数数据类型的对数伽马分布随机值。 |
logistic (key[, shape, dtype]) |
采样给定形状和浮点数数据类型的 logistic 随机值。 |
lognormal (key[, sigma, shape, dtype]) |
采样给定形状和浮点数数据类型的对数正态分布随机值。 |
maxwell (key[, shape, dtype]) |
从单边 Maxwell 分布中采样。 |
multivariate_normal (key, mean, cov[, shape, …]) |
采样给定均值和协方差的多变量正态分布随机值。 |
normal (key[, shape, dtype]) |
采样给定形状和浮点数数据类型的标准正态分布随机值。 |
orthogonal (key, n[, shape, dtype]) |
从正交群 O(n) 中均匀采样。 |
pareto (key, b[, shape, dtype]) |
采样给定形状和浮点数数据类型的帕累托分布随机值。 |
permutation (key, x[, axis, independent]) |
返回随机排列的数组或范围。 |
poisson (key, lam[, shape, dtype]) |
采样给定形状和整数数据类型的泊松分布随机值。 |
rademacher (key[, shape, dtype]) |
从 Rademacher 分布中采样。 |
randint (key, shape, minval, maxval[, dtype]) |
用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机整数值。 |
[rayleigh (key, scale[, shape, dtype]) |
用给定的形状和浮点数数据类型示例瑞利随机值。 |
t (key, df[, shape, dtype]) |
用给定的形状和浮点数数据类型示例学生 t 分布随机值。 |
triangular (key, left, mode, right[, shape, …]) |
用给定的形状和浮点数数据类型示例三角形随机值。 |
truncated_normal (key, lower, upper[, shape, …]) |
用给定的形状和数据类型示例截断标准正态随机值。 |
uniform (key[, shape, dtype, minval, maxval]) |
用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机值。 |
[wald (key, mean[, shape, dtype]) |
用给定的形状和浮点数数据类型示例瓦尔德随机值。 |
weibull_min (key, scale, concentration[, …]) |
从威布尔分布中采样。 |
JAX 中文文档(十四)(4)https://developer.aliyun.com/article/1559759