JAX 中文文档(十六)(1)https://developer.aliyun.com/article/1559726
变更日志
最佳查看此处。
jax 0.4.31
jaxlib 0.4.31
- Bug 修复
- 修复了一个 bug,导致 jit 在快速路径中错误处理负的静态参数。
jax 0.4.30(2024 年 6 月 18 日)
- 变更
- JAX 支持 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本已提升到 0.4.0,但此次发布已回滚,以便 TensorFlow 和 JAX 的用户有足够时间迁移到更新的 TensorFlow 版本。
jax.experimental.mesh_utils
现在可以为 TPU v5e 创建高效的网格。- 现在,jax 直接依赖于 jaxlib。这一变更由 CUDA 插件开关驱动:不再存在多个 jaxlib 变体。您可以通过
pip install jax
安装仅支持 CPU 的 jax,无需额外的内容。 - 添加了导出和序列化 JAX 函数的 API。此功能曾存在于
jax.experimental.export
中(正在弃用),现在将位于jax.export
中。请参阅文档。
- 弃用信息
- 内部漂亮打印工具
jax.core.pp_*
已弃用,并将在将来的版本中移除。 - 对追踪器的哈希化已弃用,并将在未来的 JAX 版本中导致
TypeError
。这在先前的 JAX 版本中是一种情况,但在最近几个 JAX 版本中出现了意外的退化。 jax.experimental.export
已弃用。请改用jax.export
。参见迁移指南。- 在大多数情况下,现在已弃用将数组作为 dtype 的传递方式;例如,对于数组
x
和y
,x.astype(y)
将引发警告。要消除警告,请使用x.astype(y.dtype)
。 jax.xla_computation
已弃用,并将在将来的版本中移除。请使用 AOT API 以获得与jax.xla_computation
相同的功能。
jax.xla_computation(fn)(*args, **kwargs)
可以替换为jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
。- 您还可以使用
jax.stages.Lowered
的.out_info
属性来获取输出信息(例如树结构、形状和 dtype)。 - 对于跨后端的降低,您可以将
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
替换为jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
。
jaxlib 0.4.30(2024 年 6 月 18 日)
- 不再支持单片 CUDA jaxlibs。您必须使用基于插件的安装方式(
pip install jax[cuda12]
或pip install jax[cuda12_local]
)。
jax 0.4.29(2024 年 6 月 10 日)
- 变更
- 我们预计这将是支持单片 CUDA jaxlib 的 JAX 和 jaxlib 的最后一个版本发布。未来的版本将使用基于插件的 CUDA jaxlib(例如
pip install jax[cuda12]
)。 - JAX 现在要求 ml_dtypes 版本为 0.4.0 或更新。
- 移除了对旧版
jax.experimental.export
API 的向后兼容支持。不再可以使用from jax.experimental.export import export
,而应改为from jax.experimental import export
。已自 0.4.24 版本起弃用该功能。 - 在
jax.tree.all()
和jax.tree_util.tree_all()
中添加了is_leaf
参数。
- 弃用
- 弃用了
jax.sharding.XLACompatibleSharding
。请使用jax.sharding.Sharding
。 jax.experimental.Exported.in_shardings
已重命名为jax.experimental.Exported.in_shardings_hlo
。out_shardings
也是如此。旧名称将在 3 个月后移除。- 移除了一些先前弃用的 API:
- 来自
jax.core
:non_negative_dim
,DimSize
,Shape
- 来自
jax.lax
:tie_in
- 来自
jax.nn
:normalize
- 来自
jax.interpreters.xla
:backend_specific_translations
,translations
,register_translation
,xla_destructure
,TranslationRule
,TranslationContext
,XlaOp
。
jax.numpy.linalg.matrix_rank()
的tol
参数即将弃用并很快将被移除。请改用rtol
。jax.numpy.linalg.pinv()
的rcond
参数即将弃用并很快将被移除。请改用rtol
。- 已移除不推荐使用的
jax.config
子模块。要配置 JAX,请使用import jax
,然后通过jax.config
引用配置对象。 jax.random
API 现在不再接受批量键,先前一些 API 无意中接受了。未来建议在这些情况下显式使用jax.vmap()
。- 在
jax.scipy.special.beta()
中,为了与其他beta
API 保持一致性,已将x
和y
参数重命名为a
和b
。
- 新功能
- 添加了
jax.experimental.Exported.in_shardings_jax()
来构建可以与存储在Exported
对象中的 HloShardings 在 JAX API 中使用的 shardings。
jaxlib 0.4.29(2024 年 6 月 10 日)
- Bug 修复
- 修复了 XLA 不正确分片某些连接操作的 bug,表现为累积归约输出不正确(#21403)。
- 修复了 XLA:CPU 错误编译某些矩阵乘法融合的 bug(https://github.com/openxla/xla/pull/13301)。
- 修复了 GPU 上的编译器崩溃(https://github.com/google/jax/issues/21396)。
- 弃用
jax.tree.map(f, None, non-None)
现在会发出DeprecationWarning
,并且在未来的 jax 版本中将引发错误。None
只是其自身的树前缀。为保留当前行为,您可以请求jax.tree.map
将None
视为叶子值,方法是写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
。
jax 0.4.28(2024 年 5 月 9 日)
- Bug 修复
- 撤销了导致 Equinox 失效的
make_jaxpr
更改(#21116)。
- 弃用与移除
jax.numpy.sort()
和jax.numpy.argsort()
的kind
参数现已移除。请改用stable=True
或stable=False
。- 从
jax.experimental.pallas.gpu
模块中移除了get_compute_capability
。请改用由jax.devices()
或jax.local_devices()
返回的 GPU 设备的compute_capability
属性。 jax.numpy.reshape()
的newshape
参数已被弃用,并将很快移除。请改用shape
。
- 变更
- 本版本 jaxlib 的最低版本为 0.4.27。
jaxlib 0.4.28 (2024 年 5 月 9 日)
- Bug 修复
- 修复了在 Python 3.10 或更早版本中的数组和 JIT Python 对象类型名称中的内存损坏 bug。
- 修复了在 CUDA 12.4 下的警告
'+ptx84' is not a recognized feature for this target
。 - 修复了 CPU 上的缓慢编译问题。
- 变更
- 现在的 Windows 构建使用 Clang 而不是 MSVC。
jax 0.4.27 (2024 年 5 月 7 日)
- 新功能
- 新增了
jax.numpy.unstack()
和jax.numpy.cumulative_sum()
,遵循其在 2023 年标准的数组 API 中的添加,这很快将被 NumPy 采纳。 - 新增了一个新的配置选项
jax_cpu_collectives_implementation
,用于选择 CPU 后端使用的跨进程集合操作的实现。可用选项为'none'
(默认)、'gloo'
和'mpi'
(需要 jaxlib 0.4.26)。如果设置为'none'
,则禁用跨进程集合操作。
- 变更
jax.pure_callback()
、jax.experimental.io_callback()
和jax.debug.callback()
现在使用jax.Array
而不是np.ndarray
。您可以通过在传递给回调之前通过jax.tree.map(np.asarray, args)
转换参数来恢复旧的行为。complex_arr.astype(bool)
现在遵循与 NumPy 相同的语义,当complex_arr
等于0 + 0j
时返回 False,否则返回 True。core.Token
现在是一个包装jax.Array
的非平凡类。可以创建并将其传递到计算中,以建立依赖关系。已移除了单例对象core.token
,现在用户应该创建和使用新的core.Token
对象。- 在 GPU 上,默认情况下,Threefry PRNG 实现不再降低为内核调用。这种选择可以在编译时减少运行时内存使用。可以通过
jax.config.update('jax_threefry_gpu_kernel_lowering', True)
恢复先前的行为,即产生内核调用。如果新的默认行为导致问题,请报告 bug。否则,我们计划在未来的版本中移除此标志。
- 废弃和移除
- Pallas 现在完全采用 XLA 编译 GPU 上的内核。通过 Triton Python API 的旧降低通路已被移除,
JAX_TRITON_COMPILE_VIA_XLA
环境变量不再起作用。 jax.numpy.clip()
现在具有新的参数签名:a
、a_min
和a_max
已被弃用,改用x
(仅位置参数)、min
和max
(#20550)。- JAX 数组的
device()
方法已被移除,自 JAX v0.4.21 弃用后。请改用arr.devices()
。 - 对于
jax.nn.softmax()
和jax.nn.log_softmax()
,initial
参数已弃用;现在支持不设置 softmax 的空输入。 - 在
jax.jit()
中,传递无效的static_argnums
或static_argnames
现在会导致错误,而不是警告。 - 最低的 jaxlib 版本现在是 0.4.23。
jax.numpy.hypot()
函数现在在传递复数输入时会发出弃用警告。在弃用完成时,将会引发错误。- 标量参数传递给
jax.numpy.nonzero()
、jax.numpy.where()
及其相关函数现在会引发错误,这与 NumPy 中的类似变更一致。 - 配置选项
jax_cpu_enable_gloo_collectives
已不推荐使用。请改用jax.config.update('jax_cpu_collectives_implementation', 'gloo')
。 - 在 JAX v0.4.22 中弃用并移除了
jax.Array.device_buffer
和jax.Array.device_buffers
方法。改用jax.Array.addressable_shards
和jax.Array.addressable_data()
。 jax.numpy.where
的condition
、x
和y
参数现在只能按位置传递,这是在 JAX v0.4.21 中关键字被弃用后的变更。- 现在在
jax.lax.linalg
中函数的非数组参数必须通过关键字指定。之前会引发 DeprecationWarning。 - 现在在几个
jax.numpy
的 API 中(包括apply_along_axis()
、apply_over_axes()
、inner()
、outer()
、cross()
、kron()
和lexsort()
),需要使用类似数组的参数。
- Bug 修复
- 当
copy=True
时,jax.numpy.astype()
现在总是返回一个副本。之前当输出数组的 dtype 与输入数组相同时,不会进行复制。这可能会导致一些内存使用增加。默认值设置为copy=False
以保持向后兼容性。
jaxlib 0.4.27 (2024 年 5 月 7 日)
jax 0.4.26 (2024 年 4 月 3 日)
- 新功能
- 添加了
jax.numpy.trapezoid()
,跟随 NumPy 2.0 中此函数的添加。
- 变更
- 复数值
jax.numpy.geomspace()
现在选择与 NumPy 2.0 一致的对数螺旋分支。 - 在
jax.vmap
下,lax.rng_bit_generator
的行为,以及'rbg'
和'unsafe_rbg'
的 PRNG 实现,已发生变化,使得在密钥上进行映射只会从批处理中的第一个密钥生成随机数。 - 文档现在使用
jax.random.key
构造 PRNG 密钥数组,而不是jax.random.PRNGKey
。
- 弃用和移除
jax.tree_map()
已弃用;请改用jax.tree.map
,或者为了与旧版 JAX 向后兼容性,请使用jax.tree_util.tree_map()
。jax.clear_backends()
因其名字不确保做其名义暗示的操作,可能导致意外后果而被弃用,例如,它不会销毁现有的后端或释放相应的资源。如果只想清理编译缓存,请使用jax.clear_caches()
。为了向后兼容性或者确实需要切换/重新初始化默认后端,请使用jax.extend.backend.clear_backends()
。- 废弃了
jax.experimental.maps
模块和jax.experimental.maps.xmap
。请使用jax.experimental.shard_map
或在表达 SPMD 设备并行计算时使用带有spmd_axis_name
参数的jax.vmap
。 - 废弃了
jax.experimental.host_callback
模块。请改用新的 JAX 外部回调。添加了JAX_HOST_CALLBACK_LEGACY
标志以帮助过渡到新的回调。参见 #20385 进行讨论。 - 将无法转换为 JAX 数组的参数传递给
jax.numpy.array_equal()
和jax.numpy.array_equiv()
现在会导致异常。 - 移除了废弃标志
jax_parallel_functions_output_gda
。该标志早已废弃且无效;其使用对操作无影响。 - 先前弃用的导入
jax.interpreters.ad.config
和jax.interpreters.ad.source_info_util
现已移除。请改用jax.config
和jax.extend.source_info_util
。 - JAX 导出不再支持旧的序列化版本。自 2023 年 10 月 27 日起支持版本 9,并自 2024 年 2 月 1 日起成为默认版本。详见版本描述。此更改可能会影响将 JAX 序列化版本设置为低于 9 的客户端。
jaxlib 0.4.26(2024 年 4 月 3 日)
- 更改
- JAX 现在仅支持 CUDA 12.1 或更新版本。不再支持 CUDA 11.8。
- JAX 现在支持 NumPy 2.0。
jax 0.4.25(2024 年 2 月 26 日)
- 新功能
- 增加了对 CUDA 数组接口 的导入支持(需要 jaxlib 0.4.24)。
- JAX 数组现在支持 NumPy 风格的标量布尔索引,例如
x[True]
或x[False]
。 - 新增了
jax.tree
模块,提供了更便捷的接口来引用jax.tree_util
中的函数。 jax.tree.transpose()
(即jax.tree_util.tree_transpose()
)现在接受inner_treedef=None
,在这种情况下,内部 treedef 将自动推断。
- 更改
- Pallas 现在使用 XLA 而不是 Triton Python API 来编译 Triton 内核。您可以通过将
JAX_TRITON_COMPILE_VIA_XLA
环境变量设置为"0"
来恢复到旧行为。 jax.interpreters.xla
中几个在 v0.4.24 中移除的废弃 API 在 v0.4.25 中重新添加,包括backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
和XLAOp
。这些仍被视为废弃,将来会在更好的替代品可用时再次移除。参见 #19816 进行讨论。
- 废弃与移除
jax.numpy.linalg.solve()
现在对于批处理的 1D 解法(b.ndim > 1
)显示废弃警告。将来将将这些视为批处理的 2D 解法。- 将非标量数组转换为 Python 标量现在会引发错误,无论数组的大小如何。在非标量大小为 1 的数组的情况下,之前会引发弃用警告。这与 NumPy 中的类似弃用相似。
- 先前弃用的配置 API 已经根据标准的 3 个月弃用周期被移除(请参见 API 兼容性)。这些包括
jax.config.config
对象和jax.config
的define_*_state
和DEFINE_*
方法。
- 通过
import jax.config
导入jax.config
子模块已经被弃用。配置 JAX 请使用import jax
,然后通过jax.config
引用配置对象。 - 最低的 jaxlib 版本现在是 0.4.20。
jaxlib 0.4.25(2024 年 2 月 26 日)
jax 0.4.24(2024 年 2 月 6 日)
- 变更
- JAX 降级到 StableHLO 不再依赖于物理设备。如果您的原语在降级规则中使用
custom_partitioning
或 JAX 回调,即传递给mlir.register_lowering
的rule
参数的函数,则将原语添加到jax._src.dispatch.prim_requires_devices_during_lowering
集合中。这是因为custom_partitioning
和 JAX 回调需要物理设备在降级过程中创建Sharding
。这是一个临时状态,直到我们可以在没有物理设备的情况下创建Sharding
。 jax.numpy.argsort()
和jax.numpy.sort()
现在支持stable
和descending
参数。- 对形状多态性处理的若干更改(用于
jax.experimental.jax2tf
和jax.experimental.export
中):
- 更清晰地打印符号表达式(#19227)
- 增加了在维度变量上指定符号约束的功能。这使得形状多态性更加表达,并且提供了一个方法来解决不等式推理中的限制。参见 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。
- 随着符号约束的增加(#19235),我们现在认为来自不同作用域的维度变量是不同的,即使它们具有相同的名称。来自不同作用域的符号表达式不能相互作用,例如,在算术操作中。作用域由
jax.experimental.jax2tf.convert()
,jax.experimental.export.symbolic_shape()
,jax.experimental.export.symbolic_args_specs()
引入。符号表达式e
的作用域可以通过e.scope
读取,并传递给上述函数以指导它们在给定作用域中构建符号表达式。请参阅 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。 - 简化和加快等式比较,如果它们的差异的标准化形式减少为 0,则认为两个符号维度相等(#19231;请注意,这可能导致用户可见的行为变化)
- 改进了不确定的不等式比较的错误消息 (#19235)。
core.non_negative_dim
API(最近引入)已弃用,引入了core.max_dim
和core.min_dim
(#18953) 用于表示符号维度的max
和min
。您可以使用core.max_dim(d, 0)
代替core.non_negative_dim(d)
。shape_poly.is_poly_dim
已弃用,改为使用export.is_symbolic_dim
(#19282)。export.args_specs
已弃用,应使用export.symbolic_args_specs ({jax-issue}
#19283)
。shape_poly.PolyShape
和jax2tf.PolyShape
已弃用,应使用字符串来指定多态形状 (#19284)。- JAX 默认的本地序列化版本现在是 9。这对
jax.experimental.jax2tf
和jax.experimental.export
非常重要。请参阅 版本号说明。
- 重构了
jax.experimental.export
的 API。现在应使用from jax.experimental import export
而不是from jax.experimental.export import export
。旧的导入方式将在 3 个月的弃用期后停止支持。 - 添加了
jax.scipy.stats.sem()
。 - 带有
return_inverse = True
的jax.numpy.unique()
返回重塑为输入维度的反向索引,遵循 NumPy 2.0 中类似的更改numpy.unique()
。 jax.numpy.sign()
现在对非零复数输入返回x / abs(x)
。这与 NumPy 2.0 版本中numpy.sign()
的行为一致。- 带有
return_sign=True
的jax.scipy.special.logsumexp()
现在使用 NumPy 2.0 中的复数符号约定x / abs(x)
。这与 SciPy v1.13 中的scipy.special.logsumexp()
的行为一致。 - JAX 现在支持布尔型 DLPack 类型的导入和导出。之前布尔值无法导入,并且以整数形式导出。
- 弃用和移除:
- 删除了许多先前弃用的函数,遵循标准的 3+ 个月弃用周期(请参阅 API 兼容性)。
- 从
jax.core
中移除:TracerArrayConversionError
、TracerIntegerConversionError
、UnexpectedTracerError
、as_hashable_function
、collections
、dtypes
、lu
、map
、namedtuple
、partial
、pp
、ref
、safe_zip
、safe_map
、source_info_util
、total_ordering
、traceback_util
、tuple_delete
、tuple_insert
和zip
。 - 从
jax.lax
中移除:dtypes
、itertools
、naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
。 jax.linear_util
子模块及其所有内容。jax.prng
子模块及其所有内容。- 来自
jax.random
:PRNGKeyArray
、KeyArray
、default_prng_impl
、threefry_2x32
、threefry2x32_key
、threefry2x32_p
、rbg_key
和unsafe_rbg_key
。 - 来自
jax.tree_util
:register_keypaths
、AttributeKeyPathEntry
和GetItemKeyPathEntry
。 - 来自
jax.interpreters.xla
:backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
、axis_groups
、ShapedArray
、ConcreteArray
、AxisEnv
、backend_compile
和XLAOp
。 - 来自
jax.numpy
:NINF
、NZERO
、PZERO
、row_stack
、issubsctype
、trapz
和in1d
。 - 来自
jax.scipy.linalg
:tril
和triu
。
- 已弃用的方法
PRNGKeyArray.unsafe_raw_array
已被移除。请使用jax.random.key_data()
替代。 bool(empty_array)
现在引发错误,而不是返回False
。这之前会引发弃用警告,并遵循 NumPy 中类似的更改。- 弃用了对 mhlo MLIR 方言的支持。JAX 不再使用 mhlo 方言,而是改用 stablehlo。将来将删除指称“mhlo”的 API。请改用“stablehlo”方言。
jax.random
:直接将批处理密钥传递给随机数生成函数(如bits()
、gamma()
等)已弃用,并将发出FutureWarning
。请使用jax.vmap
进行显式批处理。- 弃用了
jax.lax.tie_in()
:自 JAX v0.2.0 以来已成为无操作。
jaxlib 0.4.24(2024 年 2 月 6 日)
- 变更
- JAX 现在支持 CUDA 12.3 和 CUDA 11.8。不再支持 CUDA 12.2。
cost_analysis
现在可以与交叉编译的Compiled
对象一起使用(例如,在非 TPU 计算机上使用.lower().compile()
编译为云 TPU 时使用拓扑对象)。- 添加了CUDA 数组接口导入支持(需要 jax 0.4.25)。
jax 0.4.23(2023 年 12 月 13 日)
jaxlib 0.4.23(2023 年 12 月 13 日)
- 修复了导致 GPU 编译器在编译期间产生冗长日志的错误。
jax 0.4.22(2023 年 12 月 13 日)
- 弃用内容
- JAX 数组的
device_buffer
和device_buffers
属性已弃用。显式缓冲区已被更灵活的数组分片接口取代,但以前的输出可以通过以下方式恢复:
arr.device_buffer
变为arr.addressable_data(0)
arr.device_buffers
变为[x.data for x in arr.addressable_shards]
jaxlib 0.4.22(2023 年 12 月 13 日)
jax 0.4.21(2023 年 12 月 4 日)
- 新特性
- 添加了
jax.nn.squareplus
。
- 变更
- 最低 jaxlib 版本现在为 0.4.19。
- 现在发布的 Wheels 使用 clang 而不是 gcc 构建。
- 在调用
jax.distributed.initialize()
之前,强制确保设备后端未初始化。 - 在云 TPU 环境中自动化
jax.distributed.initialize()
的参数。
- 弃用内容
- 从
jax.scipy.linalg.solve()
中删除了先前弃用的sym_pos
参数。请改用assume_a='pos'
。 - 将
None
传递给jax.array()
或jax.asarray()
,无论是直接传递还是在列表或元组中传递,已被弃用并现在引发FutureWarning
。当前转换为 NaN,在将来将引发TypeError
。 - 通过关键字参数传递
condition
、x
和y
参数给jax.numpy.where
已被弃用,以匹配numpy.where
。 - 传递给
jax.numpy.array_equal()
和jax.numpy.array_equiv()
的参数如果不能转换为 JAX 数组,则已被弃用并现在引发DeprecationWaning
。当前函数返回 False,在将来将引发异常。 - JAX 数组的
device()
方法已被弃用。根据上下文,可能替换为以下之一:
jax.Array.devices()
返回数组使用的所有设备集。jax.Array.sharding
给出了数组使用的分片配置。
jaxlib 0.4.21 (2023 年 12 月 4 日)
- 变更
- 为了添加分布式 CPU 支持的准备工作,JAX 现在将 CPU 设备与 GPU 和 TPU 设备相同对待,即:
jax.devices()
包括分布式作业中所有设备,即使这些设备不在当前进程中也包括在内。jax.local_devices()
仍然只包括当前进程中的设备,因此如果对jax.devices()
的更改影响到您,您可能更希望使用jax.local_devices()
。- CPU 设备现在在分布式作业中接收全局唯一的 ID 号码;以前 CPU 设备将接收进程本地的 ID 号码。
- 每个 CPU 设备的
process_index
现在将与同一进程中的任何 GPU 或 TPU 设备匹配;以前 CPU 设备的process_index
总是 0。
- 在 NVIDIA GPU 上,JAX 现在优先选择 Jacobi SVD 求解器用于大小不超过 1024x1024 的矩阵。与非 Jacobi 版本相比,Jacobi 求解器似乎更快。
- Bug 修复
- 当传递具有非有限值的数组给非对称特征分解时发生错误/挂起(#18226)。现在,具有非有限值的数组将产生由 NaN 组成的输出数组。
jax 0.4.20 (2023 年 11 月 2 日)
jaxlib 0.4.20 (2023 年 11 月 2 日)
- Bug 修复
- 修复了 E4M3 和 E5M2 float8 类型之间的一些类型混淆。
jax 0.4.19 (2023 年 10 月 19 日)
- 新功能
- 添加了
jax.typing.DTypeLike
,可用于注释可转换为 JAX 数据类型的对象。 - 添加了
jax.numpy.fill_diagonal
。
- 变更
- 现在 JAX 要求 SciPy 1.9 或更新版本。
- Bug 修复
- 在多控制器分布式 JAX 程序中,只有进程 0 将写入持久编译缓存条目。如果缓存放置在网络文件系统(如 GCS)上,则修复了写入争用问题。
- 当决定已安装的 cusolver 和 cufft 版本是否至少与 JAX 构建的版本一样新时,版本检查现在不再考虑补丁版本。
jaxlib 0.4.19 (2023 年 10 月 19 日)
- 变更
- jaxlib 现在始终优先使用通过 pip 安装的 NVIDIA CUDA 库(nvidia-… packages),而不管
LD_LIBRARY_PATH
中命名的其他 CUDA 安装。如果因此导致问题且意图是使用系统安装的 CUDA,则解决方法是移除 pip 安装的 CUDA 库包。
jax 0.4.18(2023 年 10 月 6 日)
jaxlib 0.4.18(2023 年 10 月 6 日)
- 变更:
- CUDA jaxlibs 现在依赖于用户安装兼容的 NCCL 版本。如果使用推荐的
cuda12_pip
安装,NCCL 应会自动安装。目前需要 NCCL 2.16 或更新版本。 - 现在我们提供 Linux aarch64 wheels,包括带有和不带有 NVIDIA GPU 支持的版本。
jax.Array.item()
现在支持可选的索引参数。
- 弃用:
- 一些
jax.lax
中的内部实用程序和无意导出已被弃用,并将在将来的版本中移除。
jax.lax.dtypes
: 使用jax.dtypes
替代。jax.lax.itertools
:使用itertools
替代。naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
是内部实用程序,现在已弃用且没有替代。
- Bug 修复
- 修复了云 TPU 回归,因 smem 导致编译 OOM。
jax 0.4.17(2023 年 10 月 3 日)
- 新功能
- 新增了
jax.numpy.bitwise_count()
函数,与最近添加到 NumPy 的类似函数的 API 匹配。
- 弃用:
- 移除了弃用的模块
jax.abstract_arrays
及其所有内容。 jax.random
中的命名键构造函数已被弃用。改为向jax.random.PRNGKey()
或jax.random.key()
传递impl
参数:
random.threefry2x32_key(seed)
变为random.PRNGKey(seed, impl='threefry2x32')
random.rbg_key(seed)
变为random.PRNGKey(seed, impl='rbg')
random.unsafe_rbg_key(seed)
变为random.PRNGKey(seed, impl='unsafe_rbg')
- 变更:
- CUDA:JAX 现在会验证其找到的 CUDA 库是否至少与 JAX 构建时使用的 CUDA 库一样新。如果发现较旧的库,JAX 将引发异常,因为这比神秘的故障和崩溃更可取。
- 移除了“未找到 GPU/TPU”的警告。而是在 Linux 上,如果发现但未使用 NVIDIA GPU 或 Google TPU,并且未指定
--jax_platforms
,则发出警告。 jax.scipy.stats.mode()
现在在跨尺寸为 0 的轴上取模时返回 0 计数,与 SciPy 1.11 中scipy.stats.mode
的行为相匹配。- 大多数
jax.numpy
函数和属性现在都具有完全定义的类型存根。以前,这些函数中的许多被静态类型检查器(如mypy
和pytype
)视为Any
。
jaxlib 0.4.17(2023 年 10 月 3 日)
- 变更:
- Python 3.12 wheels 已添加到此版本中。
- CUDA 12 wheels 现在需要 CUDA 12.2 或更新版本以及 cuDNN 8.9.4 或更新版本。
- Bug 修复:
- 修复了当 JAX CPU 后端初始化时,ABSL 输出大量日志的问题。
jax 0.4.16(2023 年 9 月 18 日)
- 变更:
- 添加了
jax.numpy.ufunc
,以及jax.numpy.frompyfunc()
,它可以将任何标量值函数转换为类似于numpy.ufunc()
的对象,具有outer()
、reduce()
、accumulate()
、at()
和reduceat()
等方法(#17054)。 - 添加了
jax.scipy.integrate.trapezoid()
。 - 在非 IPython 环境下:当引发异常时,JAX 现在会从回溯中过滤掉其内部帧的整体。(之前会出现“未过滤堆栈跟踪”)。这应该会产生更友好的堆栈跟踪。详见 此处 的示例。此行为可以通过设置
JAX_TRACEBACK_FILTERING=remove_frames
(用于两个单独的未过滤/过滤后的堆栈跟踪,即旧的行为)或JAX_TRACEBACK_FILTERING=off
(用于一个未过滤的堆栈跟踪)来改变。 - jax2tf 默认序列化版本现在是 7,引入了新的形状 安全断言。
- 传递给
jax.sharding.Mesh
的设备应该是可哈希的。这特别适用于模拟设备或用户创建的设备。jax.devices()
已经是可哈希的。
- 破坏性变更:
- jax2tf 现在默认使用本地序列化。请查阅 jax2tf 文档 获取详细信息以及覆盖默认设置的机制。
- 选项
--jax_coordination_service
已被移除。现在总是True
。 jax.jaxpr_util
已从公共 JAX 命名空间中移除。JAX_USE_PJRT_C_API_ON_TPU
不再生效(即它总是默认为 true)。- 自 2021 年 12 月引入的兼容性标志
--jax_host_callback_ad_transforms
已被移除。
- 弃用:
- 根据 NumPy NEP-52,几个
jax.numpy
API 已经被弃用:
jax.numpy.NINF
已经被弃用。请改用-jax.numpy.inf
。jax.numpy.PZERO
已经被弃用。请改用0.0
。jax.numpy.NZERO
已经被弃用。请改用-0.0
。jax.numpy.issubsctype(x, t)
已经被弃用。请改用jax.numpy.issubdtype(x.dtype, t)
。jax.numpy.row_stack
已经被弃用。请改用jax.numpy.vstack
。jax.numpy.in1d
已经被弃用。请改用jax.numpy.isin
。jax.numpy.trapz
已经被弃用。请改用jax.scipy.integrate.trapezoid
。
jax.scipy.linalg.tril
和jax.scipy.linalg.triu
已经被弃用,遵循 SciPy 的做法。请改用jax.numpy.tril
和jax.numpy.triu
。jax.lax.prod
已经在 JAX v0.4.11 中被移除,之前已被弃用。请改用内置的math.prod
。- 从
jax.interpreters.xla
导出的一些与为自定义 JAX 原语定义 HLO 降低规则有关的内容已经被弃用。应该使用jax.interpreters.mlir
中的 StableHLO 降低实用工具来定义自定义原语。 - 在经过三个月的弃用期后,以下先前弃用的函数已被移除:
jax.abstract_arrays.ShapedArray
: 使用jax.core.ShapedArray
。jax.abstract_arrays.raise_to_shaped
: 使用jax.core.raise_to_shaped
。jax.numpy.alltrue
: 使用jax.numpy.all
。jax.numpy.sometrue
: 使用jax.numpy.any
。jax.numpy.product
: 使用jax.numpy.prod
。jax.numpy.cumproduct
: 使用jax.numpy.cumprod
。
- 弃用/移除:
- 内部子模块
jax.prng
现已弃用。其内容可在jax.extend.random
中找到。 - 内部子模块路径
jax.linear_util
已被弃用。请使用jax.extend.linear_util
替代(jax.extend 的一部分:扩展模块)。 jax.random.PRNGKeyArray
和jax.random.KeyArray
已弃用。请在类型注释中使用jax.Array
,并在运行时使用jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)
来检测类型化的 PRNG 密钥。- 方法
PRNGKeyArray.unsafe_raw_array
已弃用。请改用jax.random.key_data()
。 jax.experimental.pjit.with_sharding_constraint
已弃用。请使用jax.lax.with_sharding_constraint
替代。- 内部工具函数
jax.core.is_opaque_dtype
和jax.core.has_opaque_dtype
已被移除。不透明数据类型已重命名为扩展数据类型;请使用jnp.issubdtype(dtype, jax.dtypes.extended)
替代(自 jax v0.4.14 起可用)。 - 实用工具函数
jax.interpreters.xla.register_collective_primitive
已被移除。在最新的 JAX 发行版中,此实用工具无任何作用,可以安全移除其调用。 - 内部子模块路径
jax.linear_util
已被弃用。请使用jax.extend.linear_util
替代(jax.extend 的一部分:扩展模块)。
jaxlib 0.4.16(2023 年 9 月 18 日)
- 变更:
- 在 NVIDIA GPU 上,通过实验性的 jax 稀疏 API 进行的稀疏 CSR 矩阵乘法不再使用确定性算法。此更改是为了与 CUDA 12.2.1 兼容性而进行的。
- Bug 修复:
- 修复了由于关于乱序段和 IMAGE_REL_AMD64_ADDR32NB 重定位的致命 LLVM 错误而在 Windows 上崩溃的问题(https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4)。
jax 0.4.14(2023 年 7 月 27 日)
- 变更:
jax.jit
接受donate_argnames
作为参数。其语义类似于static_argnames
。如果既不提供donate_argnums
也不提供donate_argnames
,则不会捐赠任何参数。如果不提供donate_argnums
但提供了donate_argnames
,或者反之,则 JAX 使用inspect.signature(fun)
来查找与donate_argnames
(或其反向)相对应的任何位置参数。如果同时提供了donate_argnums
和donate_argnames
,则不使用inspect.signature
,并且只有实际参数列在donate_argnums
或donate_argnames
中将被捐赠。jax.random.gamma()
已重新设计为更高效的算法,具有更健壮的端点行为(#16779)。这意味着给定key
的值序列在 JAX v0.4.13 和 v0.4.14 之间的gamma
和相关抽样器(包括jax.random.ball()
、jax.random.beta()
、jax.random.chisquare()
、jax.random.dirichlet()
、jax.random.generalized_normal()
、jax.random.loggamma()
、jax.random.t()
)将发生变化。
- 删除项:
- 自弃用以来已超过 3 个月的
in_axis_resources
和out_axis_resources
已从 pjit 中删除。请使用in_shardings
和out_shardings
进行替换。这是一个安全和简单的名称替换。它不会改变任何当前的 pjit 语义,也不会破坏任何代码。您仍然可以将PartitionSpecs
传递给in_shardings
和out_shardings
。
- 弃用项:
- 根据 https://jax.readthedocs.io/en/latest/deprecation.html,已删除对 Python 3.8 的支持。
- 根据 https://jax.readthedocs.io/en/latest/deprecation.html,JAX 现在要求 NumPy 1.22 或更新版本。
- 不再支持通过位置向
jax.numpy.ndarray.at()
传递可选参数,已在 JAX 版本 0.4.7 中被弃用。例如,不再使用x.at[i].get(True)
,而是使用x.at[i].get(indices_are_sorted=True)
。 - 以下
jax.Array
方法在 JAX v0.4.5 中被弃用后已被移除:
jax.Array.broadcast
: 改用jax.lax.broadcast()
。jax.Array.broadcast_in_dim
: 改用jax.lax.broadcast_in_dim()
。jax.Array.split
: 使用jax.numpy.split()
替代。
- 在之前的弃用之后,以下 API 已被移除:
jax.ad
: 使用jax.interpreters.ad
。jax.curry
: 使用curry = lambda f: partial(partial, f)
。jax.partial_eval
: 使用jax.interpreters.partial_eval
。jax.pxla
: 使用jax.interpreters.pxla
。jax.xla
: 使用jax.interpreters.xla
。jax.ShapedArray
: 使用jax.core.ShapedArray
。jax.interpreters.pxla.device_put
: 使用jax.device_put()
。jax.interpreters.pxla.make_sharded_device_array
: 使用jax.make_array_from_single_device_arrays()
。jax.interpreters.pxla.ShardedDeviceArray
: 使用jax.Array
。jax.numpy.DeviceArray
: 使用jax.Array
。jax.stages.Compiled.compiler_ir
: 使用jax.stages.Compiled.as_text()
。
- 破坏性变更:
- JAX 现在要求 ml_dtypes 版本 0.2.0 或更新版本。
- 为了修复一个边缘情况,调用
jax.lax.cond()
时,如果第二个和第三个参数是可调用的,则使用五个参数总是解析为文档中记录的 “common operands”cond
行为,即使其他操作数也是可调用的。参见 #16413。 - 已删除无效配置选项
jax_array
和jax_jit_pjit_api_merge
。这些选项默认情况下自许多版本以来都为 true。
- 新功能:
- JAX 现在支持配置标志
--jax_serialization_version
和环境变量JAX_SERIALIZATION_VERSION
来控制序列化版本(#16746)。 - 在形状多态性存在的情况下,jax2tf 现在生成检查某些形状约束的代码,如果序列化版本至少为 7。详见 https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism。
JAX 中文文档(十六)(3)https://developer.aliyun.com/article/1559729