JAX 中文文档(十六)(2)

简介: JAX 中文文档(十六)

JAX 中文文档(十六)(1)https://developer.aliyun.com/article/1559726


变更日志

原文:jax.readthedocs.io/en/latest/changelog.html

最佳查看此处

jax 0.4.31

jaxlib 0.4.31

  • Bug 修复
  • 修复了一个 bug,导致 jit 在快速路径中错误处理负的静态参数。

jax 0.4.30(2024 年 6 月 18 日)

  • 变更
  • JAX 支持 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本已提升到 0.4.0,但此次发布已回滚,以便 TensorFlow 和 JAX 的用户有足够时间迁移到更新的 TensorFlow 版本。
  • jax.experimental.mesh_utils现在可以为 TPU v5e 创建高效的网格。
  • 现在,jax 直接依赖于 jaxlib。这一变更由 CUDA 插件开关驱动:不再存在多个 jaxlib 变体。您可以通过pip install jax安装仅支持 CPU 的 jax,无需额外的内容。
  • 添加了导出和序列化 JAX 函数的 API。此功能曾存在于jax.experimental.export中(正在弃用),现在将位于jax.export中。请参阅文档
  • 弃用信息
  • 内部漂亮打印工具jax.core.pp_*已弃用,并将在将来的版本中移除。
  • 对追踪器的哈希化已弃用,并将在未来的 JAX 版本中导致TypeError。这在先前的 JAX 版本中是一种情况,但在最近几个 JAX 版本中出现了意外的退化。
  • jax.experimental.export已弃用。请改用jax.export。参见迁移指南
  • 在大多数情况下,现在已弃用将数组作为 dtype 的传递方式;例如,对于数组xyx.astype(y)将引发警告。要消除警告,请使用x.astype(y.dtype)
  • jax.xla_computation已弃用,并将在将来的版本中移除。请使用 AOT API 以获得与jax.xla_computation相同的功能。
  • jax.xla_computation(fn)(*args, **kwargs)可以替换为jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
  • 您还可以使用jax.stages.Lowered.out_info属性来获取输出信息(例如树结构、形状和 dtype)。
  • 对于跨后端的降低,您可以将jax.xla_computation(fn, backend='tpu')(*args, **kwargs)替换为jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')

jaxlib 0.4.30(2024 年 6 月 18 日)

  • 不再支持单片 CUDA jaxlibs。您必须使用基于插件的安装方式(pip install jax[cuda12]pip install jax[cuda12_local])。

jax 0.4.29(2024 年 6 月 10 日)

  • 变更
  • 我们预计这将是支持单片 CUDA jaxlib 的 JAX 和 jaxlib 的最后一个版本发布。未来的版本将使用基于插件的 CUDA jaxlib(例如pip install jax[cuda12])。
  • JAX 现在要求 ml_dtypes 版本为 0.4.0 或更新。
  • 移除了对旧版 jax.experimental.export API 的向后兼容支持。不再可以使用 from jax.experimental.export import export,而应改为 from jax.experimental import export。已自 0.4.24 版本起弃用该功能。
  • jax.tree.all()jax.tree_util.tree_all() 中添加了 is_leaf 参数。
  • 弃用
  • 弃用了 jax.sharding.XLACompatibleSharding。请使用 jax.sharding.Sharding
  • jax.experimental.Exported.in_shardings 已重命名为 jax.experimental.Exported.in_shardings_hloout_shardings 也是如此。旧名称将在 3 个月后移除。
  • 移除了一些先前弃用的 API:
  • 来自 jax.corenon_negative_dimDimSizeShape
  • 来自 jax.laxtie_in
  • 来自 jax.nnnormalize
  • 来自 jax.interpreters.xlabackend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextXlaOp
  • jax.numpy.linalg.matrix_rank()tol 参数即将弃用并很快将被移除。请改用 rtol
  • jax.numpy.linalg.pinv()rcond 参数即将弃用并很快将被移除。请改用 rtol
  • 已移除不推荐使用的 jax.config 子模块。要配置 JAX,请使用 import jax,然后通过 jax.config 引用配置对象。
  • jax.random API 现在不再接受批量键,先前一些 API 无意中接受了。未来建议在这些情况下显式使用 jax.vmap()
  • jax.scipy.special.beta() 中,为了与其他 beta API 保持一致性,已将 xy 参数重命名为 ab
  • 新功能
  • 添加了 jax.experimental.Exported.in_shardings_jax() 来构建可以与存储在 Exported 对象中的 HloShardings 在 JAX API 中使用的 shardings。

jaxlib 0.4.29(2024 年 6 月 10 日)

  • Bug 修复
  • 弃用
  • jax.tree.map(f, None, non-None) 现在会发出 DeprecationWarning,并且在未来的 jax 版本中将引发错误。None 只是其自身的树前缀。为保留当前行为,您可以请求 jax.tree.mapNone 视为叶子值,方法是写:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)

jax 0.4.28(2024 年 5 月 9 日)

  • Bug 修复
  • 撤销了导致 Equinox 失效的 make_jaxpr 更改(#21116)。
  • 弃用与移除
  • jax.numpy.sort()jax.numpy.argsort()kind 参数现已移除。请改用 stable=Truestable=False
  • jax.experimental.pallas.gpu 模块中移除了 get_compute_capability。请改用由 jax.devices()jax.local_devices() 返回的 GPU 设备的 compute_capability 属性。
  • jax.numpy.reshape()newshape 参数已被弃用,并将很快移除。请改用 shape
  • 变更
  • 本版本 jaxlib 的最低版本为 0.4.27。

jaxlib 0.4.28 (2024 年 5 月 9 日)

  • Bug 修复
  • 修复了在 Python 3.10 或更早版本中的数组和 JIT Python 对象类型名称中的内存损坏 bug。
  • 修复了在 CUDA 12.4 下的警告 '+ptx84' is not a recognized feature for this target
  • 修复了 CPU 上的缓慢编译问题。
  • 变更
  • 现在的 Windows 构建使用 Clang 而不是 MSVC。

jax 0.4.27 (2024 年 5 月 7 日)

  • 新功能
  • 新增了 jax.numpy.unstack()jax.numpy.cumulative_sum(),遵循其在 2023 年标准的数组 API 中的添加,这很快将被 NumPy 采纳。
  • 新增了一个新的配置选项 jax_cpu_collectives_implementation,用于选择 CPU 后端使用的跨进程集合操作的实现。可用选项为 'none'(默认)、'gloo''mpi'(需要 jaxlib 0.4.26)。如果设置为 'none',则禁用跨进程集合操作。
  • 变更
  • jax.pure_callback()jax.experimental.io_callback()jax.debug.callback() 现在使用 jax.Array 而不是 np.ndarray。您可以通过在传递给回调之前通过 jax.tree.map(np.asarray, args) 转换参数来恢复旧的行为。
  • complex_arr.astype(bool) 现在遵循与 NumPy 相同的语义,当 complex_arr 等于 0 + 0j 时返回 False,否则返回 True。
  • core.Token 现在是一个包装 jax.Array 的非平凡类。可以创建并将其传递到计算中,以建立依赖关系。已移除了单例对象 core.token,现在用户应该创建和使用新的 core.Token 对象。
  • 在 GPU 上,默认情况下,Threefry PRNG 实现不再降低为内核调用。这种选择可以在编译时减少运行时内存使用。可以通过 jax.config.update('jax_threefry_gpu_kernel_lowering', True) 恢复先前的行为,即产生内核调用。如果新的默认行为导致问题,请报告 bug。否则,我们计划在未来的版本中移除此标志。
  • 废弃和移除
  • Pallas 现在完全采用 XLA 编译 GPU 上的内核。通过 Triton Python API 的旧降低通路已被移除,JAX_TRITON_COMPILE_VIA_XLA 环境变量不再起作用。
  • jax.numpy.clip() 现在具有新的参数签名:aa_mina_max 已被弃用,改用 x(仅位置参数)、minmax#20550)。
  • JAX 数组的 device() 方法已被移除,自 JAX v0.4.21 弃用后。请改用 arr.devices()
  • 对于jax.nn.softmax()jax.nn.log_softmax()initial参数已弃用;现在支持不设置 softmax 的空输入。
  • jax.jit()中,传递无效的static_argnumsstatic_argnames现在会导致错误,而不是警告。
  • 最低的 jaxlib 版本现在是 0.4.23。
  • jax.numpy.hypot()函数现在在传递复数输入时会发出弃用警告。在弃用完成时,将会引发错误。
  • 标量参数传递给jax.numpy.nonzero()jax.numpy.where()及其相关函数现在会引发错误,这与 NumPy 中的类似变更一致。
  • 配置选项jax_cpu_enable_gloo_collectives已不推荐使用。请改用jax.config.update('jax_cpu_collectives_implementation', 'gloo')
  • 在 JAX v0.4.22 中弃用并移除了jax.Array.device_bufferjax.Array.device_buffers方法。改用jax.Array.addressable_shardsjax.Array.addressable_data()
  • jax.numpy.whereconditionxy参数现在只能按位置传递,这是在 JAX v0.4.21 中关键字被弃用后的变更。
  • 现在在jax.lax.linalg中函数的非数组参数必须通过关键字指定。之前会引发 DeprecationWarning。
  • 现在在几个jax.numpy的 API 中(包括apply_along_axis()apply_over_axes()inner()outer()cross()kron()lexsort()),需要使用类似数组的参数。
  • Bug 修复
  • copy=True时,jax.numpy.astype()现在总是返回一个副本。之前当输出数组的 dtype 与输入数组相同时,不会进行复制。这可能会导致一些内存使用增加。默认值设置为copy=False以保持向后兼容性。

jaxlib 0.4.27 (2024 年 5 月 7 日)

jax 0.4.26 (2024 年 4 月 3 日)

  • 新功能
  • 添加了jax.numpy.trapezoid(),跟随 NumPy 2.0 中此函数的添加。
  • 变更
  • 复数值jax.numpy.geomspace()现在选择与 NumPy 2.0 一致的对数螺旋分支。
  • jax.vmap下,lax.rng_bit_generator的行为,以及'rbg''unsafe_rbg'的 PRNG 实现,已发生变化,使得在密钥上进行映射只会从批处理中的第一个密钥生成随机数。
  • 文档现在使用jax.random.key构造 PRNG 密钥数组,而不是jax.random.PRNGKey
  • 弃用和移除
  • jax.tree_map()已弃用;请改用jax.tree.map,或者为了与旧版 JAX 向后兼容性,请使用jax.tree_util.tree_map()
  • jax.clear_backends()因其名字不确保做其名义暗示的操作,可能导致意外后果而被弃用,例如,它不会销毁现有的后端或释放相应的资源。如果只想清理编译缓存,请使用jax.clear_caches()。为了向后兼容性或者确实需要切换/重新初始化默认后端,请使用jax.extend.backend.clear_backends()
  • 废弃了 jax.experimental.maps 模块和 jax.experimental.maps.xmap。请使用 jax.experimental.shard_map 或在表达 SPMD 设备并行计算时使用带有 spmd_axis_name 参数的 jax.vmap
  • 废弃了 jax.experimental.host_callback 模块。请改用新的 JAX 外部回调。添加了 JAX_HOST_CALLBACK_LEGACY 标志以帮助过渡到新的回调。参见 #20385 进行讨论。
  • 将无法转换为 JAX 数组的参数传递给 jax.numpy.array_equal()jax.numpy.array_equiv() 现在会导致异常。
  • 移除了废弃标志 jax_parallel_functions_output_gda。该标志早已废弃且无效;其使用对操作无影响。
  • 先前弃用的导入 jax.interpreters.ad.configjax.interpreters.ad.source_info_util 现已移除。请改用 jax.configjax.extend.source_info_util
  • JAX 导出不再支持旧的序列化版本。自 2023 年 10 月 27 日起支持版本 9,并自 2024 年 2 月 1 日起成为默认版本。详见版本描述。此更改可能会影响将 JAX 序列化版本设置为低于 9 的客户端。

jaxlib 0.4.26(2024 年 4 月 3 日)

  • 更改
  • JAX 现在仅支持 CUDA 12.1 或更新版本。不再支持 CUDA 11.8。
  • JAX 现在支持 NumPy 2.0。

jax 0.4.25(2024 年 2 月 26 日)

  • 新功能
  • 增加了对 CUDA 数组接口 的导入支持(需要 jaxlib 0.4.24)。
  • JAX 数组现在支持 NumPy 风格的标量布尔索引,例如 x[True]x[False]
  • 新增了 jax.tree 模块,提供了更便捷的接口来引用 jax.tree_util 中的函数。
  • jax.tree.transpose()(即 jax.tree_util.tree_transpose())现在接受 inner_treedef=None,在这种情况下,内部 treedef 将自动推断。
  • 更改
  • Pallas 现在使用 XLA 而不是 Triton Python API 来编译 Triton 内核。您可以通过将 JAX_TRITON_COMPILE_VIA_XLA 环境变量设置为 "0" 来恢复到旧行为。
  • jax.interpreters.xla 中几个在 v0.4.24 中移除的废弃 API 在 v0.4.25 中重新添加,包括 backend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextXLAOp。这些仍被视为废弃,将来会在更好的替代品可用时再次移除。参见 #19816 进行讨论。
  • 废弃与移除
  • jax.numpy.linalg.solve() 现在对于批处理的 1D 解法(b.ndim > 1)显示废弃警告。将来将将这些视为批处理的 2D 解法。
  • 将非标量数组转换为 Python 标量现在会引发错误,无论数组的大小如何。在非标量大小为 1 的数组的情况下,之前会引发弃用警告。这与 NumPy 中的类似弃用相似。
  • 先前弃用的配置 API 已经根据标准的 3 个月弃用周期被移除(请参见 API 兼容性)。这些包括
  • jax.config.config对象和
  • jax.configdefine_*_stateDEFINE_*方法。
  • 通过import jax.config导入jax.config子模块已经被弃用。配置 JAX 请使用import jax,然后通过jax.config引用配置对象。
  • 最低的 jaxlib 版本现在是 0.4.20。

jaxlib 0.4.25(2024 年 2 月 26 日)

jax 0.4.24(2024 年 2 月 6 日)

  • 变更
  • JAX 降级到 StableHLO 不再依赖于物理设备。如果您的原语在降级规则中使用custom_partitioning或 JAX 回调,即传递给mlir.register_loweringrule参数的函数,则将原语添加到jax._src.dispatch.prim_requires_devices_during_lowering集合中。这是因为custom_partitioning和 JAX 回调需要物理设备在降级过程中创建Sharding。这是一个临时状态,直到我们可以在没有物理设备的情况下创建Sharding
  • jax.numpy.argsort()jax.numpy.sort()现在支持stabledescending参数。
  • 对形状多态性处理的若干更改(用于jax.experimental.jax2tfjax.experimental.export中):
  • 更清晰地打印符号表达式(#19227
  • 增加了在维度变量上指定符号约束的功能。这使得形状多态性更加表达,并且提供了一个方法来解决不等式推理中的限制。参见   https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints
  • 随着符号约束的增加(#19235),我们现在认为来自不同作用域的维度变量是不同的,即使它们具有相同的名称。来自不同作用域的符号表达式不能相互作用,例如,在算术操作中。作用域由jax.experimental.jax2tf.convert()jax.experimental.export.symbolic_shape()jax.experimental.export.symbolic_args_specs()引入。符号表达式e的作用域可以通过e.scope读取,并传递给上述函数以指导它们在给定作用域中构建符号表达式。请参阅   https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints
  • 简化和加快等式比较,如果它们的差异的标准化形式减少为 0,则认为两个符号维度相等(#19231;请注意,这可能导致用户可见的行为变化)
  • 改进了不确定的不等式比较的错误消息 (#19235)。
  • core.non_negative_dim API(最近引入)已弃用,引入了 core.max_dimcore.min_dim (#18953) 用于表示符号维度的 maxmin。您可以使用 core.max_dim(d, 0) 代替 core.non_negative_dim(d)
  • shape_poly.is_poly_dim 已弃用,改为使用 export.is_symbolic_dim (#19282)。
  • export.args_specs 已弃用,应使用 export.symbolic_args_specs ({jax-issue}#19283)
  • shape_poly.PolyShapejax2tf.PolyShape 已弃用,应使用字符串来指定多态形状 (#19284)。
  • JAX 默认的本地序列化版本现在是 9。这对 jax.experimental.jax2tfjax.experimental.export 非常重要。请参阅 版本号说明
  • 重构了 jax.experimental.export 的 API。现在应使用 from jax.experimental import export 而不是 from jax.experimental.export import export。旧的导入方式将在 3 个月的弃用期后停止支持。
  • 添加了 jax.scipy.stats.sem()
  • 带有 return_inverse = Truejax.numpy.unique() 返回重塑为输入维度的反向索引,遵循 NumPy 2.0 中类似的更改 numpy.unique()
  • jax.numpy.sign() 现在对非零复数输入返回 x / abs(x)。这与 NumPy 2.0 版本中 numpy.sign() 的行为一致。
  • 带有 return_sign=Truejax.scipy.special.logsumexp() 现在使用 NumPy 2.0 中的复数符号约定 x / abs(x)。这与 SciPy v1.13 中的 scipy.special.logsumexp() 的行为一致。
  • JAX 现在支持布尔型 DLPack 类型的导入和导出。之前布尔值无法导入,并且以整数形式导出。
  • 弃用和移除:
  • 删除了许多先前弃用的函数,遵循标准的 3+ 个月弃用周期(请参阅 API 兼容性)。
  • jax.core 中移除:TracerArrayConversionErrorTracerIntegerConversionErrorUnexpectedTracerErroras_hashable_functioncollectionsdtypeslumapnamedtuplepartialpprefsafe_zipsafe_mapsource_info_utiltotal_orderingtraceback_utiltuple_deletetuple_insertzip
  • jax.lax 中移除:dtypesitertoolsnaryopnaryop_dtype_rulestandard_abstract_evalstandard_naryopstandard_primitivestandard_unopunopunop_dtype_rule
  • jax.linear_util 子模块及其所有内容。
  • jax.prng 子模块及其所有内容。
  • 来自 jax.randomPRNGKeyArrayKeyArraydefault_prng_implthreefry_2x32threefry2x32_keythreefry2x32_prbg_keyunsafe_rbg_key
  • 来自 jax.tree_utilregister_keypathsAttributeKeyPathEntryGetItemKeyPathEntry
  • 来自 jax.interpreters.xlabackend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextaxis_groupsShapedArrayConcreteArrayAxisEnvbackend_compileXLAOp
  • 来自 jax.numpyNINFNZEROPZEROrow_stackissubsctypetrapzin1d
  • 来自 jax.scipy.linalgtriltriu
  • 已弃用的方法 PRNGKeyArray.unsafe_raw_array 已被移除。请使用 jax.random.key_data() 替代。
  • bool(empty_array) 现在引发错误,而不是返回 False。这之前会引发弃用警告,并遵循 NumPy 中类似的更改。
  • 弃用了对 mhlo MLIR 方言的支持。JAX 不再使用 mhlo 方言,而是改用 stablehlo。将来将删除指称“mhlo”的 API。请改用“stablehlo”方言。
  • jax.random:直接将批处理密钥传递给随机数生成函数(如 bits()gamma() 等)已弃用,并将发出 FutureWarning。请使用 jax.vmap 进行显式批处理。
  • 弃用了 jax.lax.tie_in():自 JAX v0.2.0 以来已成为无操作。

jaxlib 0.4.24(2024 年 2 月 6 日)

  • 变更
  • JAX 现在支持 CUDA 12.3 和 CUDA 11.8。不再支持 CUDA 12.2。
  • cost_analysis 现在可以与交叉编译的 Compiled 对象一起使用(例如,在非 TPU 计算机上使用 .lower().compile() 编译为云 TPU 时使用拓扑对象)。
  • 添加了CUDA 数组接口导入支持(需要 jax 0.4.25)。

jax 0.4.23(2023 年 12 月 13 日)

jaxlib 0.4.23(2023 年 12 月 13 日)

  • 修复了导致 GPU 编译器在编译期间产生冗长日志的错误。

jax 0.4.22(2023 年 12 月 13 日)

  • 弃用内容
  • JAX 数组的 device_bufferdevice_buffers 属性已弃用。显式缓冲区已被更灵活的数组分片接口取代,但以前的输出可以通过以下方式恢复:
  • arr.device_buffer 变为 arr.addressable_data(0)
  • arr.device_buffers 变为 [x.data for x in arr.addressable_shards]

jaxlib 0.4.22(2023 年 12 月 13 日)

jax 0.4.21(2023 年 12 月 4 日)

  • 新特性
  • 添加了 jax.nn.squareplus
  • 变更
  • 最低 jaxlib 版本现在为 0.4.19。
  • 现在发布的 Wheels 使用 clang 而不是 gcc 构建。
  • 在调用 jax.distributed.initialize() 之前,强制确保设备后端未初始化。
  • 在云 TPU 环境中自动化 jax.distributed.initialize() 的参数。
  • 弃用内容
  • jax.scipy.linalg.solve() 中删除了先前弃用的 sym_pos 参数。请改用 assume_a='pos'
  • None 传递给 jax.array()jax.asarray(),无论是直接传递还是在列表或元组中传递,已被弃用并现在引发 FutureWarning。当前转换为 NaN,在将来将引发 TypeError
  • 通过关键字参数传递 conditionxy 参数给 jax.numpy.where 已被弃用,以匹配 numpy.where
  • 传递给 jax.numpy.array_equal()jax.numpy.array_equiv() 的参数如果不能转换为 JAX 数组,则已被弃用并现在引发 DeprecationWaning。当前函数返回 False,在将来将引发异常。
  • JAX 数组的 device() 方法已被弃用。根据上下文,可能替换为以下之一:
  • jax.Array.devices() 返回数组使用的所有设备集。
  • jax.Array.sharding 给出了数组使用的分片配置。

jaxlib 0.4.21 (2023 年 12 月 4 日)

  • 变更
  • 为了添加分布式 CPU 支持的准备工作,JAX 现在将 CPU 设备与 GPU 和 TPU 设备相同对待,即:
  • jax.devices() 包括分布式作业中所有设备,即使这些设备不在当前进程中也包括在内。jax.local_devices() 仍然只包括当前进程中的设备,因此如果对 jax.devices() 的更改影响到您,您可能更希望使用 jax.local_devices()
  • CPU 设备现在在分布式作业中接收全局唯一的 ID 号码;以前 CPU 设备将接收进程本地的 ID 号码。
  • 每个 CPU 设备的 process_index 现在将与同一进程中的任何 GPU 或 TPU 设备匹配;以前 CPU 设备的 process_index 总是 0。
  • 在 NVIDIA GPU 上,JAX 现在优先选择 Jacobi SVD 求解器用于大小不超过 1024x1024 的矩阵。与非 Jacobi 版本相比,Jacobi 求解器似乎更快。
  • Bug 修复
  • 当传递具有非有限值的数组给非对称特征分解时发生错误/挂起(#18226)。现在,具有非有限值的数组将产生由 NaN 组成的输出数组。

jax 0.4.20 (2023 年 11 月 2 日)

jaxlib 0.4.20 (2023 年 11 月 2 日)

  • Bug 修复
  • 修复了 E4M3 和 E5M2 float8 类型之间的一些类型混淆。

jax 0.4.19 (2023 年 10 月 19 日)

  • 新功能
  • 添加了 jax.typing.DTypeLike,可用于注释可转换为 JAX 数据类型的对象。
  • 添加了 jax.numpy.fill_diagonal
  • 变更
  • 现在 JAX 要求 SciPy 1.9 或更新版本。
  • Bug 修复
  • 在多控制器分布式 JAX 程序中,只有进程 0 将写入持久编译缓存条目。如果缓存放置在网络文件系统(如 GCS)上,则修复了写入争用问题。
  • 当决定已安装的 cusolver 和 cufft 版本是否至少与 JAX 构建的版本一样新时,版本检查现在不再考虑补丁版本。

jaxlib 0.4.19 (2023 年 10 月 19 日)

  • 变更
  • jaxlib 现在始终优先使用通过 pip 安装的 NVIDIA CUDA 库(nvidia-… packages),而不管 LD_LIBRARY_PATH 中命名的其他 CUDA 安装。如果因此导致问题且意图是使用系统安装的 CUDA,则解决方法是移除 pip 安装的 CUDA 库包。

jax 0.4.18(2023 年 10 月 6 日)

jaxlib 0.4.18(2023 年 10 月 6 日)

  • 变更:
  • CUDA jaxlibs 现在依赖于用户安装兼容的 NCCL 版本。如果使用推荐的 cuda12_pip 安装,NCCL 应会自动安装。目前需要 NCCL 2.16 或更新版本。
  • 现在我们提供 Linux aarch64 wheels,包括带有和不带有 NVIDIA GPU 支持的版本。
  • jax.Array.item() 现在支持可选的索引参数。
  • 弃用:
  • 一些jax.lax中的内部实用程序和无意导出已被弃用,并将在将来的版本中移除。
  • jax.lax.dtypes: 使用 jax.dtypes 替代。
  • jax.lax.itertools:使用 itertools 替代。
  • naryopnaryop_dtype_rulestandard_abstract_evalstandard_naryopstandard_primitivestandard_unopunopunop_dtype_rule 是内部实用程序,现在已弃用且没有替代。
  • Bug 修复
  • 修复了云 TPU 回归,因 smem 导致编译 OOM。

jax 0.4.17(2023 年 10 月 3 日)

  • 新功能
  • 新增了 jax.numpy.bitwise_count() 函数,与最近添加到 NumPy 的类似函数的 API 匹配。
  • 弃用:
  • 移除了弃用的模块 jax.abstract_arrays 及其所有内容。
  • jax.random 中的命名键构造函数已被弃用。改为向 jax.random.PRNGKey()jax.random.key() 传递 impl 参数:
  • random.threefry2x32_key(seed) 变为 random.PRNGKey(seed, impl='threefry2x32')
  • random.rbg_key(seed) 变为 random.PRNGKey(seed, impl='rbg')
  • random.unsafe_rbg_key(seed) 变为 random.PRNGKey(seed, impl='unsafe_rbg')
  • 变更:
  • CUDA:JAX 现在会验证其找到的 CUDA 库是否至少与 JAX 构建时使用的 CUDA 库一样新。如果发现较旧的库,JAX 将引发异常,因为这比神秘的故障和崩溃更可取。
  • 移除了“未找到 GPU/TPU”的警告。而是在 Linux 上,如果发现但未使用 NVIDIA GPU 或 Google TPU,并且未指定 --jax_platforms,则发出警告。
  • jax.scipy.stats.mode() 现在在跨尺寸为 0 的轴上取模时返回 0 计数,与 SciPy 1.11 中 scipy.stats.mode 的行为相匹配。
  • 大多数 jax.numpy 函数和属性现在都具有完全定义的类型存根。以前,这些函数中的许多被静态类型检查器(如 mypypytype)视为 Any

jaxlib 0.4.17(2023 年 10 月 3 日)

  • 变更:
  • Python 3.12 wheels 已添加到此版本中。
  • CUDA 12 wheels 现在需要 CUDA 12.2 或更新版本以及 cuDNN 8.9.4 或更新版本。
  • Bug 修复:
  • 修复了当 JAX CPU 后端初始化时,ABSL 输出大量日志的问题。

jax 0.4.16(2023 年 9 月 18 日)

  • 变更:
  • 添加了 jax.numpy.ufunc,以及 jax.numpy.frompyfunc(),它可以将任何标量值函数转换为类似于 numpy.ufunc() 的对象,具有 outer()reduce()accumulate()at()reduceat() 等方法(#17054)。
  • 添加了 jax.scipy.integrate.trapezoid()
  • 在非 IPython 环境下:当引发异常时,JAX 现在会从回溯中过滤掉其内部帧的整体。(之前会出现“未过滤堆栈跟踪”)。这应该会产生更友好的堆栈跟踪。详见 此处 的示例。此行为可以通过设置 JAX_TRACEBACK_FILTERING=remove_frames(用于两个单独的未过滤/过滤后的堆栈跟踪,即旧的行为)或 JAX_TRACEBACK_FILTERING=off(用于一个未过滤的堆栈跟踪)来改变。
  • jax2tf 默认序列化版本现在是 7,引入了新的形状 安全断言
  • 传递给 jax.sharding.Mesh 的设备应该是可哈希的。这特别适用于模拟设备或用户创建的设备。jax.devices() 已经是可哈希的。
  • 破坏性变更:
  • jax2tf 现在默认使用本地序列化。请查阅 jax2tf 文档 获取详细信息以及覆盖默认设置的机制。
  • 选项 --jax_coordination_service 已被移除。现在总是 True
  • jax.jaxpr_util 已从公共 JAX 命名空间中移除。
  • JAX_USE_PJRT_C_API_ON_TPU 不再生效(即它总是默认为 true)。
  • 自 2021 年 12 月引入的兼容性标志 --jax_host_callback_ad_transforms 已被移除。
  • 弃用:
  • 根据 NumPy NEP-52,几个 jax.numpy API 已经被弃用:
  • jax.numpy.NINF 已经被弃用。请改用 -jax.numpy.inf
  • jax.numpy.PZERO 已经被弃用。请改用 0.0
  • jax.numpy.NZERO 已经被弃用。请改用 -0.0
  • jax.numpy.issubsctype(x, t) 已经被弃用。请改用 jax.numpy.issubdtype(x.dtype, t)
  • jax.numpy.row_stack 已经被弃用。请改用 jax.numpy.vstack
  • jax.numpy.in1d 已经被弃用。请改用 jax.numpy.isin
  • jax.numpy.trapz 已经被弃用。请改用 jax.scipy.integrate.trapezoid
  • jax.scipy.linalg.triljax.scipy.linalg.triu 已经被弃用,遵循 SciPy 的做法。请改用 jax.numpy.triljax.numpy.triu
  • jax.lax.prod 已经在 JAX v0.4.11 中被移除,之前已被弃用。请改用内置的 math.prod
  • jax.interpreters.xla 导出的一些与为自定义 JAX 原语定义 HLO 降低规则有关的内容已经被弃用。应该使用 jax.interpreters.mlir 中的 StableHLO 降低实用工具来定义自定义原语。
  • 在经过三个月的弃用期后,以下先前弃用的函数已被移除:
  • jax.abstract_arrays.ShapedArray: 使用 jax.core.ShapedArray
  • jax.abstract_arrays.raise_to_shaped: 使用 jax.core.raise_to_shaped
  • jax.numpy.alltrue: 使用 jax.numpy.all
  • jax.numpy.sometrue: 使用 jax.numpy.any
  • jax.numpy.product: 使用 jax.numpy.prod
  • jax.numpy.cumproduct: 使用 jax.numpy.cumprod
  • 弃用/移除:
  • 内部子模块 jax.prng 现已弃用。其内容可在 jax.extend.random 中找到。
  • 内部子模块路径 jax.linear_util 已被弃用。请使用 jax.extend.linear_util 替代(jax.extend 的一部分:扩展模块)。
  • jax.random.PRNGKeyArrayjax.random.KeyArray 已弃用。请在类型注释中使用 jax.Array,并在运行时使用 jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) 来检测类型化的 PRNG 密钥。
  • 方法 PRNGKeyArray.unsafe_raw_array 已弃用。请改用 jax.random.key_data()
  • jax.experimental.pjit.with_sharding_constraint 已弃用。请使用 jax.lax.with_sharding_constraint 替代。
  • 内部工具函数 jax.core.is_opaque_dtypejax.core.has_opaque_dtype 已被移除。不透明数据类型已重命名为扩展数据类型;请使用 jnp.issubdtype(dtype, jax.dtypes.extended) 替代(自 jax v0.4.14 起可用)。
  • 实用工具函数 jax.interpreters.xla.register_collective_primitive 已被移除。在最新的 JAX 发行版中,此实用工具无任何作用,可以安全移除其调用。
  • 内部子模块路径 jax.linear_util 已被弃用。请使用 jax.extend.linear_util 替代(jax.extend 的一部分:扩展模块)。

jaxlib 0.4.16(2023 年 9 月 18 日)

  • 变更:
  • 在 NVIDIA GPU 上,通过实验性的 jax 稀疏 API 进行的稀疏 CSR 矩阵乘法不再使用确定性算法。此更改是为了与 CUDA 12.2.1 兼容性而进行的。
  • Bug 修复:

jax 0.4.14(2023 年 7 月 27 日)

  • 变更:
  • jax.jit 接受 donate_argnames 作为参数。其语义类似于 static_argnames。如果既不提供 donate_argnums 也不提供 donate_argnames,则不会捐赠任何参数。如果不提供 donate_argnums 但提供了 donate_argnames,或者反之,则 JAX 使用 inspect.signature(fun) 来查找与 donate_argnames(或其反向)相对应的任何位置参数。如果同时提供了 donate_argnumsdonate_argnames,则不使用 inspect.signature,并且只有实际参数列在 donate_argnumsdonate_argnames 中将被捐赠。
  • jax.random.gamma() 已重新设计为更高效的算法,具有更健壮的端点行为(#16779)。这意味着给定 key 的值序列在 JAX v0.4.13 和 v0.4.14 之间的 gamma 和相关抽样器(包括 jax.random.ball()jax.random.beta()jax.random.chisquare()jax.random.dirichlet()jax.random.generalized_normal()jax.random.loggamma()jax.random.t())将发生变化。
  • 删除项:
  • 自弃用以来已超过 3 个月的 in_axis_resourcesout_axis_resources 已从 pjit 中删除。请使用 in_shardingsout_shardings 进行替换。这是一个安全和简单的名称替换。它不会改变任何当前的 pjit 语义,也不会破坏任何代码。您仍然可以将 PartitionSpecs 传递给 in_shardingsout_shardings
  • 弃用项:
  • jax.Array.broadcast: 改用 jax.lax.broadcast()
  • jax.Array.broadcast_in_dim: 改用 jax.lax.broadcast_in_dim()
  • jax.Array.split: 使用 jax.numpy.split() 替代。
  • 在之前的弃用之后,以下 API 已被移除:
  • jax.ad: 使用 jax.interpreters.ad
  • jax.curry: 使用 curry = lambda f: partial(partial, f)
  • jax.partial_eval: 使用 jax.interpreters.partial_eval
  • jax.pxla: 使用 jax.interpreters.pxla
  • jax.xla: 使用 jax.interpreters.xla
  • jax.ShapedArray: 使用 jax.core.ShapedArray
  • jax.interpreters.pxla.device_put: 使用 jax.device_put()
  • jax.interpreters.pxla.make_sharded_device_array: 使用 jax.make_array_from_single_device_arrays()
  • jax.interpreters.pxla.ShardedDeviceArray: 使用 jax.Array
  • jax.numpy.DeviceArray: 使用 jax.Array
  • jax.stages.Compiled.compiler_ir: 使用 jax.stages.Compiled.as_text()
  • 破坏性变更:
  • JAX 现在要求 ml_dtypes 版本 0.2.0 或更新版本。
  • 为了修复一个边缘情况,调用 jax.lax.cond() 时,如果第二个和第三个参数是可调用的,则使用五个参数总是解析为文档中记录的 “common operands” cond 行为,即使其他操作数也是可调用的。参见 #16413
  • 已删除无效配置选项 jax_arrayjax_jit_pjit_api_merge。这些选项默认情况下自许多版本以来都为 true。
  • 新功能:


JAX 中文文档(十六)(3)https://developer.aliyun.com/article/1559729

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
4月前
|
并行计算 算法框架/工具 异构计算
JAX 中文文档(十六)(5)
JAX 中文文档(十六)
67 2
|
4月前
|
并行计算 异构计算 索引
JAX 中文文档(十六)(4)
JAX 中文文档(十六)
44 2
|
4月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
53 3
|
4月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
40 3
|
4月前
|
存储 缓存 API
JAX 中文文档(十六)(1)
JAX 中文文档(十六)
35 1
|
4月前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
51 2
|
4月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
55 1
|
4月前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
33 1
|
4月前
|
并行计算 API 异构计算
JAX 中文文档(十六)(3)
JAX 中文文档(十六)
78 0
|
4月前
|
TensorFlow API 算法框架/工具
JAX 中文文档(十五)(3)
JAX 中文文档(十五)
29 0