JAX 中文文档(四)(2)https://developer.aliyun.com/article/1559793
类型提升语义
此文档描述了 JAX 的类型提升规则,即每对类型的 jax.numpy.promote_types()
结果。关于以下设计考虑的背景,请参阅 Design of Type Promotion Semantics for JAX。
JAX 的类型提升行为通过以下类型提升格确定:
其中,例如:
b1
表示np.bool_
,i2
表示np.int16
,u4
表示np.uint32
,bf
表示np.bfloat16
,f2
表示np.float16
。c8
表示np.complex64
,i*
表示 Python 的int
或弱类型的int
,f*
表示 Python 的float
或弱类型的float
,以及c*
表示 Python 的complex
或弱类型的complex
。
(关于弱类型的更多信息,请参阅下文的 JAX 中的弱类型值。)
任意两种类型之间的提升由它们在此格中的 join 决定,生成以下二进制提升表:
b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* | |
b1 | b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
u1 | u1 | u1 | u2 | u4 | u8 | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u1 | f* | c* |
u2 | u2 | u2 | u2 | u4 | u8 | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u2 | f* | c* |
u4 | u4 | u4 | u4 | u4 | u8 | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u4 | f* | c* |
u8 | u8 | u8 | u8 | u8 | u8 | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | u8 | f* | c* |
i1 | i1 | i2 | i4 | i8 | f* | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i1 | f* | c* |
i2 | i2 | i2 | i4 | i8 | f* | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i2 | f* | c* |
i4 | i4 | i4 | i4 | i8 | f* | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i4 | f* | c* |
i8 | i8 | i8 | i8 | i8 | f* | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i8 | f* | c* |
bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | f4 | f4 | f8 | c8 | c16 | bf | bf | c8 |
f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f4 | f2 | f4 | f8 | c8 | c16 | f2 | f2 | c8 |
f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f8 | c8 | c16 | f4 | f4 | c8 |
f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | c16 | c16 | f8 | f8 | c16 |
c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c16 | c8 | c16 | c8 | c8 | c8 |
c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 |
i* | i* | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
f* | f* | f* | f* | f* | f* | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | f* | f* | c* |
c* | c* | c* | c* | c* | c* | c* | c* | c* | c* | c8 | c8 | c8 | c16 | c8 | c16 | c* | c* | c* |
JAX 的类型提升规则与 NumPy 的不同,如numpy.promote_types()
所示,在上述表格中以绿色背景标出的单元格中。主要有三类区别:
- 当将弱类型值与相同类别的 JAX 类型化值进行提升时,JAX 总是偏向于 JAX 值的精度。例如,
jnp.int16(1) + 1
将返回int16
而不是像 NumPy 中那样提升为int64
。请注意,这仅适用于 Python 标量值;如果常量是 NumPy 数组,则使用上述格子结构进行类型提升。例如,jnp.int16(1) + np.array(1)
将返回int64
。 - 当将整数或布尔类型与浮点或复数类型进行提升时,JAX 总是偏向于浮点或复数类型的类型。
- JAX 支持bfloat16非标准的 16 位浮点类型 (
jax.numpy.bfloat16
),这对神经网络训练非常有用。唯一显著的提升行为是对 IEEE-754float16
的处理,其中bfloat16
提升为float32
。
NumPy 和 JAX 之间的差异是因为加速设备(如 GPU 和 TPU)在使用 64 位浮点类型时要么支付显著的性能代价(GPU),要么根本不支持 64 位浮点类型(TPU)。经典 NumPy 的提升规则过于倾向于过度提升到 64 位类型,这对设计用于加速器上运行的系统来说是个问题。
JAX 使用的浮点提升规则更适用于现代加速设备,并且在浮点类型的提升上更为谨慎。JAX 用于浮点类型的提升规则类似于 PyTorch 的规则。
Python 运算符分派的效果
请记住,Python 运算符如加号(+)会根据两个待加值的 Python 类型进行分派。这意味着,例如 np.int16(1) + 1
将按照 NumPy 的规则进行提升,而 jnp.int16(1) + 1
则按照 JAX 的规则进行提升。当两种提升类型结合使用时,可能导致令人困惑的非关联提升语义;例如 np.int16(1) + 1 + jnp.int16(1)
。
JAX 中的弱类型数值
在大多数情况下,JAX 中的弱类型值可以被视为具有与 Python 标量等效的提升行为,例如以下整数标量 2
:
>>> x = jnp.arange(5, dtype='int8') >>> 2 * x Array([0, 2, 4, 6, 8], dtype=int8)
JAX 的弱类型框架旨在防止在 JAX 值与没有明确用户指定类型的值(如 Python 标量文字)之间的二进制操作中出现不需要的类型提升。例如,如果 2
不被视为弱类型,则上述表达式将导致隐式类型提升。
>>> jnp.int32(2) * x Array([0, 2, 4, 6, 8], dtype=int32)
在 JAX 中使用时,Python 标量有时会被提升为DeviceArray
对象,例如在 JIT 编译期间。为了在这种情况下保持所需的提升语义,DeviceArray
对象携带一个weak_type
标志,该标志可以在数组的字符串表示中看到:
>>> jnp.asarray(2) Array(2, dtype=int32, weak_type=True)
如果显式指定了dtype
,则会导致标准的强类型数组值:
>>> jnp.asarray(2, dtype='int32') Array(2, dtype=int32) ```## 严格的 dtype 提升 在某些情况下,禁用隐式类型提升行为并要求所有提升都是显式的可能很有用。可以通过在 JAX 中将`jax_numpy_dtype_promotion`标志设置为`'strict'`来实现。在本地,可以通过上下文管理器来完成: ```py >>> x = jnp.float32(1) >>> y = jnp.int32(1) >>> with jax.numpy_dtype_promotion('strict'): ... z = x + y ... Traceback (most recent call last): TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.
为了方便起见,严格提升模式仍将允许安全的弱类型提升,因此您仍然可以编写混合使用 JAX 数组和 Python 标量的代码:
>>> with jax.numpy_dtype_promotion('strict'): ... z = x + 1 >>> print(z) 2.0
如果您希望全局设置配置,则可以使用标准配置更新:
jax.config.update('jax_numpy_dtype_promotion', 'strict')
要恢复默认的标准类型提升,请将此配置设置为'standard'
:
jax.config.update('jax_numpy_dtype_promotion', 'standard')
Pytrees
什么是 pytree?
在 JAX 中,我们使用术语pytree来指代由类似容器的 Python 对象构建的类似树的结构。如果它们在 pytree 注册中,则类被视为容器类,默认包括列表、元组和字典。也就是说:
- 任何类型不在 pytree 容器注册中的对象被视为叶 pytree;
- 任何类型在 pytree 容器注册中的对象,并且包含 pytrees,被视为 pytree。
对于 pytree 容器注册中的每个条目,注册了类似容器的类型,具有一对函数,用于指定如何将容器类型的实例转换为(children, metadata)
对,以及如何将这样的对返回为容器类型的实例。使用这些函数,JAX 可以将任何已注册容器对象的树规范化为元组。
示例 pytrees:
[1, "a", object()] # 3 leaves (1, (2, 3), ()) # 3 leaves [1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves
JAX 可以扩展以将其他容器类型视为 pytrees;请参见下面的扩展 pytrees。
Pytrees 和 JAX 函数
许多 JAX 函数,比如 jax.lax.scan()
,操作数组的 pytrees。JAX 函数变换可以应用于接受输入和产生输出为数组 pytrees 的函数。
将可选参数应用于 pytrees
某些 JAX 函数变换接受可选参数,用于指定如何处理特定输入或输出值(例如 vmap()
的 in_axes
和 out_axes
参数)。这些参数也可以是 pytrees,它们的结构必须与相应参数的 pytree 结构对应。特别地,在能够“匹配”这些参数 pytrees 中的叶子与参数 pytrees 中的值的情况下,通常限制参数 pytrees 为参数 pytrees 的树前缀。
例如,如果我们将以下输入传递给 vmap()
(注意函数的输入参数被视为元组):
(a1, {"k1": a2, "k2": a3})
我们可以使用以下 in_axes
pytree 指定仅映射k2
参数(axis=0
),其余参数不映射(axis=None
):
(None, {"k1": None, "k2": 0})
可选参数 pytree 结构必须与主输入 pytree 相匹配。但是,可选参数可以选择指定为“前缀” pytree,这意味着可以将单个叶值应用于整个子 pytree。例如,如果我们有与上述相同的 vmap()
输入,但希望仅映射字典参数,我们可以使用:
(None, 0) # equivalent to (None, {"k1": 0, "k2": 0})
或者,如果我们希望映射每个参数,可以简单地编写一个应用于整个参数元组 pytree 的单个叶值:
0
这恰好是vmap()
的默认in_axes
值!
相同的逻辑适用于指定转换函数的其他可选参数,例如 vmap
的 out_axes
。
查看对象的 pytree 定义
为了调试目的查看任意对象的 pytree 定义,可以使用:
from jax.tree_util import tree_structure print(tree_structure(object))
开发者信息
这主要是 JAX 内部文档,终端用户不应需要理解这一点来使用 JAX,除非在向 JAX 注册新的用户定义容器类型时。某些细节可能会更改。
内部 pytree 处理
JAX 在api.py
边界(以及控制流原语中)将 pytrees 展平为叶子列表。这使得下游 JAX 内部更简单:像grad()
、jit()
和vmap()
这样的转换可以处理接受并返回各种不同 Python 容器的用户函数,而系统的其他部分可以处理仅接受(多个)数组参数并始终返回扁平数组列表的函数。
当 JAX 展开 pytree 时,它将生成叶子列表和一个treedef
对象,该对象编码原始值的结构。然后可以使用treedef
来在转换叶子后构造匹配的结构化值。Pytrees 类似于树,而不是 DAG 或图,我们处理它们时假设具有引用透明性并且不能包含引用循环。
这里有一个简单的例子:
from jax.tree_util import tree_flatten, tree_unflatten import jax.numpy as jnp # The structured value to be transformed value_structured = [1., (2., 3.)] # The leaves in value_flat correspond to the `*` markers in value_tree value_flat, value_tree = tree_flatten(value_structured) print(f"{value_flat=}\n{value_tree=}") # Transform the flat value list using an element-wise numeric transformer transformed_flat = list(map(lambda v: v * 2., value_flat)) print(f"{transformed_flat=}") # Reconstruct the structured output, using the original transformed_structured = tree_unflatten(value_tree, transformed_flat) print(f"{transformed_structured=}")
value_flat=[1.0, 2.0, 3.0] value_tree=PyTreeDef([*, (*, *)]) transformed_flat=[2.0, 4.0, 6.0] transformed_structured=[2.0, (4.0, 6.0)]
默认情况下,pytree 容器可以是列表、元组、字典、命名元组、None、OrderedDict。其他类型的值,包括数值和 ndarray 值,都被视为叶子节点:
from collections import namedtuple Point = namedtuple('Point', ['x', 'y']) example_containers = [ (1., [2., 3.]), (1., {'b': 2., 'a': 3.}), 1., None, jnp.zeros(2), Point(1., 2.) ] def show_example(structured): flat, tree = tree_flatten(structured) unflattened = tree_unflatten(tree, flat) print(f"{structured=}\n {flat=}\n {tree=}\n {unflattened=}") for structured in example_containers: show_example(structured)
structured=(1.0, [2.0, 3.0]) flat=[1.0, 2.0, 3.0] tree=PyTreeDef((*, [*, *])) unflattened=(1.0, [2.0, 3.0]) structured=(1.0, {'b': 2.0, 'a': 3.0}) flat=[1.0, 3.0, 2.0] tree=PyTreeDef((*, {'a': *, 'b': *})) unflattened=(1.0, {'a': 3.0, 'b': 2.0}) structured=1.0 flat=[1.0] tree=PyTreeDef(*) unflattened=1.0 structured=None flat=[] tree=PyTreeDef(None) unflattened=None structured=Array([0., 0.], dtype=float32) flat=[Array([0., 0.], dtype=float32)] tree=PyTreeDef(*) unflattened=Array([0., 0.], dtype=float32) structured=Point(x=1.0, y=2.0) flat=[1.0, 2.0] tree=PyTreeDef(CustomNode(namedtuple[Point], [*, *])) unflattened=Point(x=1.0, y=2.0)
扩展 pytrees
默认情况下,被视为结构化值的任何部分,如果未被识别为内部 pytree 节点(即类似容器的)则被视为叶子节点:
class Special(object): def __init__(self, x, y): self.x = x self.y = y def __repr__(self): return "Special(x={}, y={})".format(self.x, self.y) show_example(Special(1., 2.))
structured=Special(x=1.0, y=2.0) flat=[Special(x=1.0, y=2.0)] tree=PyTreeDef(*) unflattened=Special(x=1.0, y=2.0)
被视为内部 pytree 节点的 Python 类型集是可扩展的,通过全局类型注册表,注册类型的值被递归遍历。要注册新类型,可以使用register_pytree_node()
:
from jax.tree_util import register_pytree_node class RegisteredSpecial(Special): def __repr__(self): return "RegisteredSpecial(x={}, y={})".format(self.x, self.y) def special_flatten(v): """Specifies a flattening recipe. Params: v: the value of registered type to flatten. Returns: a pair of an iterable with the children to be flattened recursively, and some opaque auxiliary data to pass back to the unflattening recipe. The auxiliary data is stored in the treedef for use during unflattening. The auxiliary data could be used, e.g., for dictionary keys. """ children = (v.x, v.y) aux_data = None return (children, aux_data) def special_unflatten(aux_data, children): """Specifies an unflattening recipe. Params: aux_data: the opaque data that was specified during flattening of the current treedef. children: the unflattened children Returns: a re-constructed object of the registered type, using the specified children and auxiliary data. """ return RegisteredSpecial(*children) # Global registration register_pytree_node( RegisteredSpecial, special_flatten, # tell JAX what are the children nodes special_unflatten # tell JAX how to pack back into a RegisteredSpecial ) show_example(RegisteredSpecial(1., 2.))
structured=RegisteredSpecial(x=1.0, y=2.0) flat=[1.0, 2.0] tree=PyTreeDef(CustomNode(RegisteredSpecial[None], [*, *])) unflattened=RegisteredSpecial(x=1.0, y=2.0)
或者,您可以在您的类上定义适当的tree_flatten
和tree_unflatten
方法,并使用register_pytree_node_class()
进行装饰:
from jax.tree_util import register_pytree_node_class @register_pytree_node_class class RegisteredSpecial2(Special): def __repr__(self): return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y) def tree_flatten(self): children = (self.x, self.y) aux_data = None return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): return cls(*children) show_example(RegisteredSpecial2(1., 2.))
structured=RegisteredSpecial2(x=1.0, y=2.0) flat=[1.0, 2.0] tree=PyTreeDef(CustomNode(RegisteredSpecial2[None], [*, *])) unflattened=RegisteredSpecial2(x=1.0, y=2.0)
在定义展开函数时,一般而言children
应包含数据结构的所有动态元素(数组、动态标量和 pytrees),而aux_data
应包含将被滚入treedef
结构的所有静态元素。有时 JAX 需要比较treedef
以确保辅助数据在扁平化过程中支持有意义的哈希和相等比较,因此必须小心处理。
操作 pytree 的所有函数都在jax.tree_util
中。
自定义 PyTrees 和初始化
用户定义的 PyTree 对象的一个常见问题是,JAX 转换有时会使用意外的值初始化它们,因此初始化时进行的任何输入验证可能会失败。例如:
class MyTree: def __init__(self, a): self.a = jnp.asarray(a) register_pytree_node(MyTree, lambda tree: ((tree.a,), None), lambda _, args: MyTree(*args)) tree = MyTree(jnp.arange(5.0)) jax.vmap(lambda x: x)(tree) # Error because object() is passed to MyTree. jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to MyTree
在第一种情况下,JAX 的内部使用object()
值的数组来推断树的结构;在第二种情况下,将树映射到树的函数的雅可比矩阵定义为树的树。
因此,自定义 PyTree 类的 __init__
和 __new__
方法通常应避免进行任何数组转换或其他输入验证,或者预期并处理这些特殊情况。例如:
class MyTree: def __init__(self, a): if not (type(a) is object or a is None or isinstance(a, MyTree)): a = jnp.asarray(a) self.a = a
另一个可能性是,结构化你的 tree_unflatten
函数,避免调用 __init__
;例如:
def tree_unflatten(aux_data, children): del aux_data # unused in this class obj = object.__new__(MyTree) obj.a = a return obj
如果你选择这条路线,请确保你的 tree_unflatten
函数在代码更新时与 __init__
保持同步。
JAX 中文文档(四)(4)https://developer.aliyun.com/article/1559796