JAX 中文文档(十六)(2)https://developer.aliyun.com/article/1559727
jaxlib 0.4.14(2023 年 7 月 27 日)
- 弃用
- 根据 https://jax.readthedocs.io/en/latest/deprecation.html,不再支持 Python 3.8。
jax 0.4.13(2023 年 6 月 22 日)
- 更改
jax.jit现在允许将None传递给in_shardings和out_shardings。语义如下:
- 对于
in_shardings,JAX 将其标记为复制,但这种行为可能会在将来更改。 - 对于
out_shardings,我们将依赖于 XLA GSPMD 分区器来确定输出的分片。
jax.experimental.pjit.pjit也允许将None传递给in_shardings和out_shardings。语义如下:
- 如果未提供网格上下文管理器,则 JAX 可自由选择所需的分片方式。
- 对于
in_shardings,JAX 将其标记为复制,但这种行为可能会在将来更改。 - 对于
out_shardings,我们将依赖于 XLA GSPMD 分区器来确定输出的分片。
- 如果提供了网格上下文管理器,
None将意味着该值将在网格的所有设备上复制。
- Executable.cost_analysis() 在 Cloud TPU 上可用
- 如果正在使用非允许的
jaxlib插件,则添加了警告。 - 添加了
jax.tree_util.tree_leaves_with_path。 None不是jax.experimental.multihost_utils.host_local_array_to_global_array或jax.experimental.multihost_utils.global_array_to_host_local_array的有效输入。如果您希望复制您的输入,请使用jax.sharding.PartitionSpec()。
- Bug 修复
- 在 CUDA 12 发布中修复了错误的轮子名称(#16362);正确的轮子名称为
cudnn89而不是cudnn88。
- 弃用
jax.experimental.jax2tf.convert()的native_serialization_strict_checks参数已被弃用,推荐使用新的native_serializaation_disabled_checks(#16347)。
jaxlib 0.4.13(2023 年 6 月 22 日)
- 更改
- 将 Windows 仅 CPU 轮子添加到
jaxlibPypi 发布中。
- Bug 修复
__cuda_array_interface__在之前的 jaxlib 版本中出现问题,现已修复(#16440)。- 并行 CUDA 内核跟踪现在默认启用于 NVIDIA GPU。
jax 0.4.12(2023 年 6 月 8 日)
- 更改
- 弃用
jax.abstract_arrays及其内容已被弃用。请参阅:mod:jax.core中的相关功能。jax.numpy.alltrue:使用jax.numpy.all。这遵循了 NumPy 版本 1.25.0 中numpy.alltrue的弃用。jax.numpy.sometrue:使用jax.numpy.any。这遵循了 NumPy 版本 1.25.0 中numpy.sometrue的弃用。jax.numpy.product:使用jax.numpy.prod。这遵循了 NumPy 版本 1.25.0 中numpy.product的弃用。jax.numpy.cumproduct:使用jax.numpy.cumprod。这遵循了 NumPy 版本 1.25.0 中numpy.cumproduct的弃用。jax.sharding.OpShardingSharding已被移除,因为它已经弃用了 3 个月。
jaxlib 0.4.12 (2023 年 6 月 8 日)
- 变更
- 包含了 Hopper(SM 版本 9.0+)GPU 的 PTX/SASS。之前的 jaxlib 版本应该可以在 Hopper 上工作,但第一次执行 JAX 操作时可能会有较长的 JIT 编译延迟。
- Bug 修复
- 修复了在 Python 3.11 下 JAX 生成的 Python 回溯中源代码行信息不正确的问题。
- 修复了在 JAX 生成的 Python 回溯的帧中打印本地变量时崩溃的问题(#16027)。
jax 0.4.11 (2023 年 5 月 31 日)
- 弃用
- 根据 API 兼容性政策,在 3 个月的弃用期后,已移除以下 API:
jax.experimental.PartitionSpec:使用jax.sharding.PartitionSpec。jax.experimental.maps.Mesh:使用jax.sharding.Mesh。jax.experimental.pjit.NamedSharding:使用jax.sharding.NamedSharding。jax.experimental.pjit.PartitionSpec:使用jax.sharding.PartitionSpec。jax.experimental.pjit.FROM_GDA。请将分片的jax.Array对象作为输入传递,并删除pjit的可选in_shardings参数。jax.interpreters.pxla.PartitionSpec:使用jax.sharding.PartitionSpec。jax.interpreters.pxla.Mesh:使用jax.sharding.Mesh。jax.interpreters.xla.Buffer:使用jax.Array。jax.interpreters.xla.Device:使用jax.Device。jax.interpreters.xla.DeviceArray:使用jax.Array。jax.interpreters.xla.device_put:使用jax.device_put。jax.interpreters.xla.xla_call_p:使用jax.experimental.pjit.pjit_p。with_sharding_constraint的axis_resources参数已被移除。请改用shardings。
jaxlib 0.4.11 (2023 年 5 月 31 日)
- 变更
- 向
Device添加了memory_stats()方法。如果支持,它将返回一个包含字符串统计名称和整数值的字典,例如"bytes_in_use",如果平台不支持内存统计,则返回 None。具体的统计数据可能因平台而异。目前仅在 Cloud TPU 上实现。 - 重新添加了对 CPU 设备上 Python 缓冲协议(
memoryview)的支持。
jax 0.4.10 (2023 年 5 月 11 日)
jaxlib 0.4.10 (2023 年 5 月 11 日)
- 变更
- 修复了阻止上一个版本在 Mac M1 上运行的
'apple-m1' is not a recognized processor for this target (ignoring processor)问题。
jax 0.4.9 (2023 年 5 月 9 日)
- 变更
experimental_cpp_jit、experimental_cpp_pjit和experimental_cpp_pmap标志已被移除。它们现在始终开启。- TPU 上奇异值分解(SVD)的准确性已经提高(需要 jaxlib 0.4.9)。
- 废弃功能
jax.experimental.gda_serialization已废弃,并已重命名为jax.experimental.array_serialization。请更改您的导入以使用jax.experimental.array_serialization。pjit的in_axis_resources和out_axis_resources参数已废弃。请分别使用in_shardings和out_shardings。- 函数
jax.numpy.msort已被移除。自 JAX v0.4.1 起已被废弃。请使用jnp.sort(a, axis=0)代替。 in_parts和out_parts参数已从jax.xla_computation中移除,因为它们只与sharded_jit一起使用,并且sharded_jit已不再使用。- 自从很久以来未被使用,
instantiate_const_outputs参数已从jax.xla_computation中移除。
jaxlib 0.4.9(2023 年 5 月 9 日)
jax 0.4.8(2023 年 3 月 29 日)
- 破坏性变更
- Cloud TPU 运行时的一个重要组件已升级。这使得以下新功能在 Cloud TPU 上可用:
jax.debug.print()、jax.debug.callback()和jax.debug.breakpoint()现在在 Cloud TPU 上可用。- 自动 TPU 内存碎片整理
- 在新的运行时组件上,不再支持
jax.experimental.host_callback()在 Cloud TPU 上的使用。如果新的jax.debugAPI 不能满足您的需求,请在JAX 问题跟踪器上提出问题。
旧的运行时组件将通过设置环境变量JAX_USE_PJRT_C_API_ON_TPU=false至少在接下来的三个月内可用。如果您发现需要出于任何原因禁用新的运行时,请在JAX 问题跟踪器上告知我们。
- 变更
- 最低 jaxlib 版本已从 0.4.6 提升至 0.4.7。
- 废弃功能
- 支持 CUDA 11.4 已被移除。JAX GPU 版本仅支持 CUDA 11.8 和 CUDA 12。如果使用旧版 CUDA 构建 jaxlib 可能会正常工作。
pmap的global_arg_shapes参数仅适用于sharded_jit,已从pmap中移除。请迁移到pjit并从pmap中移除global_arg_shapes。
jax 0.4.7(2023 年 3 月 27 日)
- 变更
- 根据 https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration,不再支持禁用
jax.config.jax_array。 - 不再支持禁用
jax.config.jax_jit_pjit_api_merge。 jax.experimental.jax2tf.convert()现在支持native_serialization参数,使用 JAX 的本机降级到 StableHLO 以获取整个 JAX 函数的 StableHLO 模块,而不是将每个 JAX 原语降级到 TensorFlow 操作。这简化了内部操作,并增加了您序列化内容与 JAX 本机语义匹配的信心。详见文档。作为这一变更的一部分,配置标志--jax2tf_default_experimental_native_lowering已重命名为--jax2tf_native_serialization。- JAX 现在依赖于
ml_dtypes,其中包含类似于 bfloat16 的 NumPy 类型的定义。这些定义以前是 JAX 的内部部分,但已拆分为一个单独的包,以便与其他项目共享。 - JAX 现在要求使用 NumPy 1.21 或更新版本以及 SciPy 1.7 或更新版本。
- 弃用信息
- 类型
jax.numpy.DeviceArray已弃用。请改用jax.Array,它是其别名。 - 类型
jax.interpreters.pxla.ShardedDeviceArray已弃用。请改用jax.Array。 - 通过位置传递额外参数给
jax.numpy.ndarray.at()已被弃用。例如,不要使用x.at[i].get(True),而是使用x.at[i].get(indices_are_sorted=True) jax.interpreters.xla.device_put已被弃用。请使用jax.device_put。jax.interpreters.pxla.device_put已被弃用。请使用jax.device_put。jax.experimental.pjit.FROM_GDA已被弃用。请将分片的 jax.Arrays 作为输入,并移除 pjit 中的in_shardings参数,因为它是可选的。
jaxlib 0.4.7(2023 年 3 月 27 日)
变更:
- jaxlib 现在依赖于
ml_dtypes,其中包含类似于 bfloat16 的 NumPy 类型的定义。这些定义以前是 JAX 的内部部分,但已拆分为一个单独的包,以便与其他项目共享。
jax 0.4.6(2023 年 3 月 9 日)
- 变更
jax.tree_util现在包含一组允许用户为其自定义 pytree 节点定义键的 API。
tree_flatten_with_path可以展平树并返回每个叶子及其键路径。tree_map_with_path可以映射一个接受键路径作为参数的函数。register_pytree_with_keys用于注册自定义 pytree 节点中键路径和叶子的外观。keystr用于漂亮地打印键路径。
jax2tf.call_tf()现在有一个新参数output_shape_dtype(默认为None),可用于声明结果的输出形状和类型。这使得jax2tf.call_tf()能够在形状多态性存在的情况下工作。(#14734)
- 弃用信息
jax.tree_util中的旧键路径 API 已被弃用,并将在 2023 年 3 月 10 日后的 3 个月内移除:
register_keypaths:请使用jax.tree_util.register_pytree_with_keys()替代。AttributeKeyPathEntry:请改用GetAttrKey。GetitemKeyPathEntry:请改用SequenceKey或DictKey。
jaxlib 0.4.6(2023 年 3 月 9 日)
jax 0.4.5(2023 年 3 月 2 日)
- 弃用信息
jax.sharding.OpShardingSharding已重命名为jax.sharding.GSPMDSharding。jax.sharding.OpShardingSharding将在 2023 年 2 月 17 日后的 3 个月内移除。- 下列
jax.Array方法已被弃用,并将在 2023 年 2 月 23 日后的 3 个月内移除:
jax.Array.broadcast:请使用jax.lax.broadcast()替代。jax.Array.broadcast_in_dim:请使用jax.lax.broadcast_in_dim()替代。jax.Array.split:请使用jax.numpy.split()替代。
jax 0.4.4(2023 年 2 月 16 日)
- 变更
jit和pjit的实现已合并。合并 jit 和 pjit 改变了 JAX 的内部实现,但不影响 JAX 的公共 API。之前,jit是一种最终风格的原语。最终风格意味着尽可能延迟创建 jaxpr 并将变换堆叠在一起。随着jit-pjit实现的合并,jit变成了一种初始风格的原语,这意味着我们尽早追踪到 jaxpr。更多信息请参见 autodidax 中的这一部分。转移到初始风格应该简化 JAX 的内部实现,并使得动态形状等功能的开发更加容易。你只能通过环境变量来禁用它,即os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'。由于它影响到 JAX 的导入时机,因此必须通过环境变量禁用它,在导入 jax 之前就需要禁用它。with_sharding_constraint的axis_resources参数已弃用。请改用shardings。如果你将其作为参数使用,则无需更改。如果你将其作为关键字参数使用,请改用shardings。axis_resources将在 2023 年 2 月 13 日后的 3 个月内删除。- 添加了
jax.typing模块,用于 JAX 函数的类型注解工具。 - 下列名称已被弃用:
jax.xla.Device和jax.interpreters.xla.Device: 使用jax.Device。jax.experimental.maps.Mesh. 使用jax.sharding.Mesh替代。jax.experimental.pjit.NamedSharding: 使用jax.sharding.NamedSharding。jax.experimental.pjit.PartitionSpec: 使用jax.sharding.PartitionSpec。jax.interpreters.pxla.Mesh: 使用jax.sharding.Mesh。jax.interpreters.pxla.PartitionSpec: 使用jax.sharding.PartitionSpec。
- Breaking Changes
jax.numpy.sum等的initial参数现在要求是一个标量,与对应的 NumPy API 保持一致。以前的行为是对非标量initial值进行广播,这是一个意外的实现细节(#14446)。
jaxlib 0.4.4(2023 年 2 月 16 日)
- Breaking changes
- 默认的
jaxlib构建中已移除对 NVIDIA Kepler 系列 GPU 的支持。如果需要 Kepler 支持,可以通过使用 Kepler 支持的源码构建jaxlib(通过build.py的--cuda_compute_capabilities=sm_35选项),不过请注意 CUDA 12 已完全停止支持 Kepler GPU。
jax 0.4.3(2023 年 2 月 8 日)
- Breaking changes
- 删除了
jax.scipy.linalg.polar_unitary(),这是一个已弃用的 JAX 扩展到 scipy API 的函数。请改用jax.scipy.linalg.polar()。
- Changes
- 添加了
jax.scipy.stats.rankdata()。
jaxlib 0.4.3(2023 年 2 月 8 日)
jax.Array现在具有非阻塞的is_ready()方法,如果数组已准备就绪则返回True(参见jax.block_until_ready())。
jax 0.4.2(2023 年 1 月 24 日)
- Breaking changes
- 删除了
jax.experimental.callback - 在存在
jax2tf形状多态性的情况下,对带有维度的操作进行了泛化处理,通过将符号维度转换为 JAX 数组来在更多场景下工作。现在,涉及符号维度和np.ndarray的操作在结果用作形状值时可能会引发错误(#14106)。 - 现在,
jaxpr对象在设置属性时会引发错误,以避免问题变异(#14102)
- 变更
jax2tf.call_tf()现在有一个新参数has_side_effects(默认为True),可用于声明实例是否可以被 JAX 优化(如死代码消除)删除或复制(#13980)。- 为了支持
jax2tf形状多态性的floordiv和mod,我们增加了更多支持。之前,存在符号维度时某些除法操作会导致错误(#14108)。
jaxlib 0.4.2(2023 年 1 月 24 日)
- 变更
- 设置
JAX_USE_PJRT_C_API_ON_TPU=1可启用新的 Cloud TPU 运行时,具备自动设备内存碎片整理功能。
jax 0.4.1(2022 年 12 月 13 日)
- 变更
- 根据 JAX 的 Python 和 NumPy 版本支持政策,不再支持 Python 3.7。
- 我们引入了
jax.Array,它是 JAX 中的统一数组类型,涵盖了DeviceArray、ShardedDeviceArray和GlobalDeviceArray类型。jax.Array类型有助于使并行成为 JAX 的核心特性,简化和统一 JAX 内部结构,并允许我们统一jit和pjit。jax.Array已在 JAX 0.4 中默认启用,并对pjitAPI 进行了一些破坏性更改。jax.Array 迁移指南 可帮助您将代码库迁移到jax.Array。您还可以查看Distributed arrays and automatic parallelization 教程,以理解新概念。 PartitionSpec和Mesh现在不再处于实验阶段。新的 API 端点是jax.sharding.PartitionSpec和jax.sharding.Mesh。jax.experimental.maps.Mesh和jax.experimental.PartitionSpec已被弃用,并将在三个月内移除。with_sharding_constraint的新公共端点是jax.lax.with_sharding_constraint。- 如果与
jax.config一起使用 ABSL 标志,那么在最初从 ABSL 标志填充 JAX 配置选项后,就不再读取或写入 ABSL 标志值。此更改改进了读取jax.config选项的性能,这些选项在 JAX 中广泛使用。 jax2tf.call_tf函数现在使用与嵌入 JAX 计算相同平台的第一个 TF 设备进行 TF 降级。以前,它使用的是 JAX 默认后端的第 0 个设备。- 现在,一些
jax.numpy函数的参数已标记为仅限位置参数,与 NumPy 匹配。 jnp.msort现已废弃,遵循 numpy 1.24 中np.msort的废弃。它将在未来的版本中移除,符合 API 兼容性策略。可以用jnp.sort(a, axis=0)替换。
jaxlib 0.4.1 (2022 年 12 月 13 日)
- 变更
- 支持 Python 3.7 已被放弃,符合 JAX 的 Python 和 NumPy 版本支持政策。
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX的行为已更改,现在分配总 GPU 内存的 XX%来预分配,而不是以前使用当前可用 GPU 内存来计算预分配。有关更多详情,请参阅GPU memory allocation。- 废弃的方法
.block_host_until_ready()已被移除。请改用.block_until_ready()。
jax 0.4.0 (2022 年 12 月 12 日)
- 此版本已被撤回。
jaxlib 0.4.0 (2022 年 12 月 12 日)
- 此版本已被撤回。
jax 0.3.25 (2022 年 11 月 15 日)
- 变更
jax.numpy.linalg.pinv()现在支持hermitian选项。jax.scipy.linalg.hessenberg()现在仅在 CPU 上支持。需要 jaxlib > 0.3.24。- 新函数
jax.lax.linalg.hessenberg(),jax.lax.linalg.tridiagonal()和jax.lax.linalg.householder_product()已添加。Householder 约简目前仅支持 CPU,三对角约简支持 CPU 和 GPU。 - 现在更经济地计算非方阵的
svd和jax.numpy.linalg.pinv的梯度。
- 突破性变更
- 删除了
jax_experimental_name_stack配置选项。 - 将字符串
axis_names参数转换为jax.experimental.maps.Mesh构造函数的单例元组,而不是将字符串解包为字符轴名称序列。
jaxlib 0.3.25 (2022 年 11 月 15 日)
- 变更
- 添加了对 CPU 和 GPU 上三对角约简的支持。
- 添加了对 CPU 上上 Hessenberg 约简的支持。
- Bug 修复
- 修复了一个 bug,导致 JAX 捕获的回溯中的帧被错误地映射到 Python 3.10+下的源行。
jax 0.3.24 (2022 年 11 月 4 日)
- 变更
- JAX 导入速度应更快。现在我们懒惰地导入 scipy,这在 JAX 的导入时间中占据了相当大的部分。
- 设置环境变量
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N可以用于限制写入持久缓存的缓存条目数量。默认情况下,编译时间超过 1 秒的计算将被缓存。
- 添加了
jax.scipy.stats.mode()。
- 如果在 TPU 上未指定顺序,则
pmap的默认设备顺序现在与单进程作业的jax.devices()匹配。以前两种排序不同,可能导致不必要的复制或内存不足错误。要求排序一致简化了问题。
- 突破性变更
jax.numpy.gradient()现在像jax.numpy中的大多数其他函数一样,禁止传递列表或元组以替代数组(#12958)。jax.numpy.linalg和jax.numpy.fft中的函数现在统一要求输入为数组形式:即不能使用列表和元组代替数组。部分属于#7737。
- 弃用
jax.sharding.MeshPspecSharding已重命名为jax.sharding.NamedSharding。jax.sharding.MeshPspecSharding名称将在 3 个月内删除。
jaxlib 0.3.24(2022 年 11 月 4 日)
- 更改
- 现在在 CPU 上可以使用缓冲器捐赠。这可能会破坏在 CPU 上标记缓冲区进行捐赠但依赖捐赠未实现的代码。
jax 0.3.23(2022 年 10 月 12 日)
- 更改
- 更新 Colab TPU 驱动程序版本以支持新的 jaxlib 发布。
jax 0.3.22(2022 年 10 月 11 日)
- 更改
- 在 TPU 初始化中添加
JAX_PLATFORMS=tpu,cpu作为默认设置,因此如果无法初始化 TPU,JAX 将引发错误而不是回退到 CPU。设置JAX_PLATFORMS=''以覆盖此行为并自动选择可用的后端(原始默认),或设置JAX_PLATFORMS=cpu以始终使用 CPU,而不管 TPU 是否可用。
- 弃用
- JAX v0.3.8 中弃用的几个测试工具现已从
jax.test_util中移除。
jaxlib 0.3.22(2022 年 10 月 11 日)
jax 0.3.21(2022 年 9 月 30 日)
- GitHub 提交记录。
- 更改
- 持久化编译缓存现在在出错时会发出警告而不是抛出异常(#12582),所以如果缓存出现问题,程序可以继续执行。设置
JAX_RAISE_PERSISTENT_CACHE_ERRORS=true可以恢复此行为。
jax 0.3.20(2022 年 9 月 28 日)
- Bug 修复:
- 添加了上一个发布版本中缺失的
.pyi文件(#12536)。 - 修复了
jax0.3.19 与其固定的 libtpu 版本之间的不兼容性(#12550)。需要 jaxlib 0.3.20。 - 修复了
setup.py注释中pip的错误网址(#12528)。
jaxlib 0.3.20(2022 年 9 月 28 日)
- GitHub 提交记录。
- Bug 修复
- 修复通过
jax_cuda_visible_devices在分布式作业中限制可见 CUDA 设备的支持。此功能对于 GPU 上的 JAX/SLURM 集成非常重要(#12533)。
jax 0.3.19(2022 年 9 月 27 日)
- GitHub 提交记录。
- 需要的 jaxlib 版本修复。
jax 0.3.18(2022 年 9 月 26 日)
- GitHub 提交记录。
- 更改
- 提前编译和编译功能(在#7733中跟踪)是稳定和公开的。查看概述和
jax.stages的 API 文档。 - 引入了
jax.Array,用于 JAX 中数组类型的isinstance检查和类型注释。请注意,这包括了对jax.numpy.ndarray在 JAX 内部对象中如何工作的一些微妙更改,因为jax.numpy.ndarray现在是jax.Array的简单别名。
- 破坏性变更
jax._src不再导入公共jax命名空间。这可能会打破使用 JAX 内部功能的用户。- 已删除
jax.soft_pmap。请改用pjit或xmap。jax.soft_pmap未记录文档。如果有文档记录,将提供弃用期。
jax 0.3.17(2022 年 8 月 31 日)
- GitHub 提交记录。
- 错误修复
- 修复了
lax.pow的梯度在指数为零时的特殊情况问题(#12041)
- 破坏性变更
jax.checkpoint(),又称jax.remat(),不再支持concrete选项,遵循前一个版本的弃用;请参阅JEP 11830。
- 变更
- 添加了
jax.pure_callback(),允许从编译函数(例如用jax.jit或jax.pmap装饰的函数)调用纯 Python 函数。
- 弃用:
- 已移除不推荐使用的
DeviceArray.tile()方法。使用jax.numpy.tile()代替(#11944)。 - 已弃用
DeviceArray.to_py()。请改用np.asarray(x)。
jax 0.3.16
- GitHub 提交记录。
- 破坏性变更
- 支持 NumPy 1.19 已被移除,根据弃用政策。请升级到 NumPy 1.20 或更新版本。
- 变更
- 添加了
jax.debug,包括用于运行时值调试的实用程序,如jax.debug.print()和jax.debug.breakpoint()。 - 添加了用于运行时值调试的新文档
- 弃用
- 移除了
jax.mask()和jax.shapecheck()API。详见#11557。 - 移除了
jax.experimental.loops。可查看#10278获取替代 API。 jax.tree_util.tree_multimap()已移除。自 JAX 版本 0.3.5 起已被弃用,jax.tree_util.tree_map()是直接替换。- 删除了
jax.experimental.stax;它长期以来一直是jax.example_libraries.stax的弃用别名。 - 移除了
jax.experimental.optimizers;它长期以来一直是jax.example_libraries.optimizers的弃用别名。 jax.checkpoint(),又称jax.remat(),有了一个新的默认实现,意味着旧的实现已被弃用;请参阅JEP 11830。
JAX 中文文档(十六)(4)https://developer.aliyun.com/article/1559730