JAX 中文文档(三)(2)https://developer.aliyun.com/article/1559703
jax.debug.print
和 jax.debug.breakpoint
原文:
jax.readthedocs.io/en/latest/debugging/print_breakpoint.html
jax.debug
包为检查在 JIT 函数中的值提供了一些有用的工具。
使用 jax.debug.print
和其他调试回调进行调试
TL;DR 使用 jax.debug.print()
在 jit
和 pmap
装饰函数中将跟踪的数组值打印到标准输出:
import jax import jax.numpy as jnp @jax.jit def f(x): jax.debug.print("🤯 {x} 🤯", x=x) y = jnp.sin(x) jax.debug.print("🤯 {y} 🤯", y=y) return y f(2.) # Prints: # 🤯 2.0 🤯 # 🤯 0.9092974662780762 🤯
对于一些转换,如 jax.grad
和 jax.vmap
,可以使用 Python 的内置 print
函数打印数值。但是 print
在 jax.jit
或 jax.pmap
下不起作用,因为这些转换会延迟数值评估。因此,请使用 jax.debug.print
代替!
语义上,jax.debug.print
大致等同于以下 Python 函数
def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None: print(fmt.format(*args, **kwargs))
除了可以被 JAX 分阶段化和转换外。有关更多详细信息,请参阅 API 参考
。
注意,fmt
不能是 f-string,因为 f-string 会立即格式化,而对于 jax.debug.print
,我们希望延迟到稍后再格式化。
何时使用“debug”打印?
对于动态(即跟踪的)数组值在 JAX 转换如 jit
、vmap
等中,应使用 jax.debug.print
进行打印。对于静态值(如数组形状或数据类型),可以使用普通的 Python print
语句。
为什么使用“debug”打印?
以调试为名,jax.debug.print
可以显示有关计算如何评估的信息:
xs = jnp.arange(3.) def f(x): jax.debug.print("x: {}", x) y = jnp.sin(x) jax.debug.print("y: {}", y) return y jax.vmap(f)(xs) # Prints: x: 0.0 # x: 1.0 # x: 2.0 # y: 0.0 # y: 0.841471 # y: 0.9092974 jax.lax.map(f, xs) # Prints: x: 0.0 # y: 0.0 # x: 1.0 # y: 0.841471 # x: 2.0 # y: 0.9092974
注意,打印的结果是以不同的顺序显示的!
通过揭示这些内部工作,jax.debug.print
的输出不遵守 JAX 的通常语义保证,例如 jax.vmap(f)(xs)
和 jax.lax.map(f, xs)
计算相同的东西(以不同的方式)。然而,这些评估顺序的细节正是我们调试时想要看到的!
因此,在重视语义保证时,请使用 jax.debug.print
进行调试。
更多 jax.debug.print
的例子
除了上述使用 jit
和 vmap
的例子外,还有几个需要记住的例子。
在 jax.pmap
下打印
当使用 jax.pmap
时,jax.debug.print
可能会被重新排序!
xs = jnp.arange(2.) def f(x): jax.debug.print("x: {}", x) return x jax.pmap(f)(xs) # Prints: x: 1.0 # x: 0.0 # OR # Prints: x: 1.0 # x: 0.0
在 jax.grad
下打印
在 jax.grad
下,jax.debug.print
只会在前向传播时打印:
def f(x): jax.debug.print("x: {}", x) return x * 2. jax.grad(f)(1.) # Prints: x: 1.0
这种行为类似于 Python 内置的 print
在 jax.grad
下的工作方式。但在这里使用 jax.debug.print
,即使调用者应用 jax.jit
,行为也是相同的。
要在反向传播中打印,只需使用 jax.custom_vjp
:
@jax.custom_vjp def print_grad(x): return x def print_grad_fwd(x): return x, None def print_grad_bwd(_, x_grad): jax.debug.print("x_grad: {}", x_grad) return (x_grad,) print_grad.defvjp(print_grad_fwd, print_grad_bwd) def f(x): x = print_grad(x) return x * 2. jax.grad(f)(1.) # Prints: x_grad: 2.0
在其他转换中打印
jax.debug.print
在其他转换如 xmap
和 pjit
中同样适用。
使用 jax.debug.callback
更多控制
实际上,jax.debug.print
是围绕 jax.debug.callback
的一个轻便封装,可以直接使用以更好地控制字符串格式化或输出类型。
语义上,jax.debug.callback
大致等同于以下 Python 函数
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None: fun(*args, **kwargs) return None
与 jax.debug.print
类似,这些回调只应用于调试输出,比如打印或绘图。打印和绘图相对无害,但如果用于其他用途,它的行为在转换中可能会让你感到意外。例如,不安全地用于计时操作是不安全的,因为回调可能会被重新排序并且是异步的(见下文)。
锐利的部分
像大多数 JAX API 一样,如果使用不当,jax.debug.print
也会给你带来麻烦。
打印结果的顺序
当 jax.debug.print
的不同调用涉及彼此不依赖的参数时,在分阶段时可能会被重新排序,例如通过 jax.jit
:
@jax.jit def f(x, y): jax.debug.print("x: {}", x) jax.debug.print("y: {}", y) return x + y f(2., 3.) # Prints: x: 2.0 # y: 3.0 # OR # Prints: y: 3.0 # x: 2.0
为什么?在幕后,编译器获得了一个计算的功能表示,其中 Python 函数的命令顺序丢失,只有数据依赖性保留。对于功能纯粹的代码用户来说,这种变化是看不见的,但是在像打印这样的副作用存在时,就会显而易见。
要保持 jax.debug.print
在 Python 函数中的原始顺序,可以使用 jax.debug.print(..., ordered=True)
,这将确保打印的相对顺序保持不变。但是在 jax.pmap
和涉及并行性的其他 JAX 转换中使用 ordered=True
会引发错误,因为在并行执行中无法保证顺序。
异步回调
根据后端不同,jax.debug.print
可能会异步执行,即不在主程序线程中。这意味着值可能在您的 JAX 函数返回值后才被打印到屏幕上。
@jax.jit def f(x): jax.debug.print("x: {}", x) return x f(2.).block_until_ready() # <do something else> # Prints: x: 2.
要阻塞函数中的 jax.debug.print
,您可以调用 jax.effects_barrier()
,它会等待函数中任何剩余的副作用也完成:
@jax.jit def f(x): jax.debug.print("x: {}", x) return x f(2.).block_until_ready() jax.effects_barrier() # Prints: x: 2. # <do something else>
性能影响
不必要的实现
虽然 jax.debug.print
设计为性能影响最小,但它可能会干扰编译器优化,并且可能会影响 JAX 程序的内存配置文件。
def f(w, b, x): logits = w.dot(x) + b jax.debug.print("logits: {}", logits) return jax.nn.relu(logits)
在这个例子中,我们在线性层和激活函数之间打印中间值。像 XLA 这样的编译器可以执行融合优化,可以避免在内存中实现 logits
。但是当我们在 logits
上使用 jax.debug.print
时,我们强制这些中间值被实现,可能会减慢程序速度并增加内存使用。
此外,当使用 jax.debug.print
与 jax.pjit
时,会发生全局同步,将值实现在单个设备上。
回调开销
jax.debug.print
本质上会在加速器和其主机之间进行通信。底层机制因后端而异(例如 GPU vs TPU),但在所有情况下,我们需要将打印的值从设备复制到主机。在 CPU 情况下,此开销较小。
此外,当使用 jax.debug.print
与 jax.pjit
时,会发生全局同步,增加了一些额外开销。
jax.debug.print
的优势和限制
优势
- 打印调试简单直观
jax.debug.callback
可用于其他无害的副作用
限制
- 添加打印语句是一个手动过程
- 可能会对性能产生影响
使用 jax.debug.breakpoint()
进行交互式检查
TL;DR 使用 jax.debug.breakpoint()
暂停执行您的 JAX 程序以检查值:
@jax.jit def f(x): y, z = jnp.sin(x), jnp.cos(x) jax.debug.breakpoint() return y * z f(2.) # ==> Pauses during execution!
jax.debug.breakpoint()
实际上只是 jax.debug.callback(...)
的一种应用,用于捕获调用堆栈信息。因此它与 jax.debug.print
具有相同的转换行为(例如,对 jax.debug.breakpoint()
进行 vmap
-ing 会将其展开到映射的轴上)。
用法
在编译的 JAX 函数中调用 jax.debug.breakpoint()
会在命中断点时暂停程序。您将看到一个类似 pdb
的提示符,允许您检查调用堆栈中的值。与 pdb
不同的是,您不能逐步执行程序,但可以恢复执行。
调试器命令:
help
- 打印出可用的命令p
- 评估表达式并打印其结果pp
- 评估表达式并漂亮地打印其结果u(p)
- 上移一个堆栈帧d(own)
- 下移一个堆栈帧w(here)/bt
- 打印出回溯l(ist)
- 打印出代码上下文c(ont(inue))
- 恢复程序的执行q(uit)/exit
- 退出程序(在 TPU 上不起作用)
示例
与 jax.lax.cond
结合使用
当与 jax.lax.cond
结合使用时,调试器可以成为检测 nan
或 inf
的有用工具。
def breakpoint_if_nonfinite(x): is_finite = jnp.isfinite(x).all() def true_fn(x): pass def false_fn(x): jax.debug.breakpoint() lax.cond(is_finite, true_fn, false_fn, x) @jax.jit def f(x, y): z = x / y breakpoint_if_nonfinite(z) return z f(2., 0.) # ==> Pauses during execution!
锐利的特性
因为 jax.debug.breakpoint
只是 jax.debug.callback
的一种应用,所以它与 jax.debug.print
一样具有锐利的特性,但也有一些额外的注意事项:
jax.debug.breakpoint
比jax.debug.print
更多地实现了中间值,因为它强制实现了调用堆栈中的所有值。jax.debug.breakpoint
的运行时开销比jax.debug.print
更大,因为它可能需要将 JAX 程序中的所有中间值从设备复制到主机。
jax.debug.breakpoint()
的优势和限制
优势
- 简单、直观且(在某种程度上)标准
- 可以同时检查多个值,上下跟踪调用堆栈。
限制
- 可能需要使用多个断点来准确定位错误的源头
- 会产生许多中间值
JAX 中文文档(三)(4)https://developer.aliyun.com/article/1559705