JAX 中文文档(三)(3)

简介: JAX 中文文档(三)

JAX 中文文档(三)(2)https://developer.aliyun.com/article/1559703


jax.debug.printjax.debug.breakpoint

原文:jax.readthedocs.io/en/latest/debugging/print_breakpoint.html

jax.debug 包为检查在 JIT 函数中的值提供了一些有用的工具。

使用 jax.debug.print 和其他调试回调进行调试

TL;DR 使用 jax.debug.print()jitpmap 装饰函数中将跟踪的数组值打印到标准输出:

import jax
import jax.numpy as jnp
@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯 

对于一些转换,如 jax.gradjax.vmap,可以使用 Python 的内置 print 函数打印数值。但是 printjax.jitjax.pmap 下不起作用,因为这些转换会延迟数值评估。因此,请使用 jax.debug.print 代替!

语义上,jax.debug.print 大致等同于以下 Python 函数

def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  print(fmt.format(*args, **kwargs)) 

除了可以被 JAX 分阶段化和转换外。有关更多详细信息,请参阅 API 参考

注意,fmt 不能是 f-string,因为 f-string 会立即格式化,而对于 jax.debug.print,我们希望延迟到稍后再格式化。

何时使用“debug”打印?

对于动态(即跟踪的)数组值在 JAX 转换如 jitvmap 等中,应使用 jax.debug.print 进行打印。对于静态值(如数组形状或数据类型),可以使用普通的 Python print 语句。

为什么使用“debug”打印?

以调试为名,jax.debug.print 可以显示有关计算如何评估的信息:

xs = jnp.arange(3.)
def f(x):
  jax.debug.print("x: {}", x)
  y = jnp.sin(x)
  jax.debug.print("y: {}", y)
  return y
jax.vmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
#         x: 2.0
#         y: 0.0
#         y: 0.841471
#         y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
#         y: 0.0
#         x: 1.0
#         y: 0.841471
#         x: 2.0
#         y: 0.9092974 

注意,打印的结果是以不同的顺序显示的!

通过揭示这些内部工作,jax.debug.print 的输出不遵守 JAX 的通常语义保证,例如 jax.vmap(f)(xs)jax.lax.map(f, xs) 计算相同的东西(以不同的方式)。然而,这些评估顺序的细节正是我们调试时想要看到的!

因此,在重视语义保证时,请使用 jax.debug.print 进行调试。

更多 jax.debug.print 的例子

除了上述使用 jitvmap 的例子外,还有几个需要记住的例子。

jax.pmap 下打印

当使用 jax.pmap 时,jax.debug.print 可能会被重新排序!

xs = jnp.arange(2.)
def f(x):
  jax.debug.print("x: {}", x)
  return x
jax.pmap(f)(xs)
# Prints: x: 1.0
#         x: 0.0
# OR
# Prints: x: 1.0
#         x: 0.0 
jax.grad 下打印

jax.grad 下,jax.debug.print 只会在前向传播时打印:

def f(x):
  jax.debug.print("x: {}", x)
  return x * 2.
jax.grad(f)(1.)
# Prints: x: 1.0 

这种行为类似于 Python 内置的 printjax.grad 下的工作方式。但在这里使用 jax.debug.print,即使调用者应用 jax.jit,行为也是相同的。

要在反向传播中打印,只需使用 jax.custom_vjp

@jax.custom_vjp
def print_grad(x):
  return x
def print_grad_fwd(x):
  return x, None
def print_grad_bwd(_, x_grad):
  jax.debug.print("x_grad: {}", x_grad)
  return (x_grad,)
print_grad.defvjp(print_grad_fwd, print_grad_bwd)
def f(x):
  x = print_grad(x)
  return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0 
在其他转换中打印

jax.debug.print 在其他转换如 xmappjit 中同样适用。

使用 jax.debug.callback 更多控制

实际上,jax.debug.print 是围绕 jax.debug.callback 的一个轻便封装,可以直接使用以更好地控制字符串格式化或输出类型。

语义上,jax.debug.callback 大致等同于以下 Python 函数

def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  fun(*args, **kwargs)
  return None 

jax.debug.print 类似,这些回调只应用于调试输出,比如打印或绘图。打印和绘图相对无害,但如果用于其他用途,它的行为在转换中可能会让你感到意外。例如,不安全地用于计时操作是不安全的,因为回调可能会被重新排序并且是异步的(见下文)。

锐利的部分

像大多数 JAX API 一样,如果使用不当,jax.debug.print 也会给你带来麻烦。

打印结果的顺序

jax.debug.print 的不同调用涉及彼此不依赖的参数时,在分阶段时可能会被重新排序,例如通过 jax.jit

@jax.jit
def f(x, y):
  jax.debug.print("x: {}", x)
  jax.debug.print("y: {}", y)
  return x + y
f(2., 3.)
# Prints: x: 2.0
#         y: 3.0
# OR
# Prints: y: 3.0
#         x: 2.0 

为什么?在幕后,编译器获得了一个计算的功能表示,其中 Python 函数的命令顺序丢失,只有数据依赖性保留。对于功能纯粹的代码用户来说,这种变化是看不见的,但是在像打印这样的副作用存在时,就会显而易见。

要保持 jax.debug.print 在 Python 函数中的原始顺序,可以使用 jax.debug.print(..., ordered=True),这将确保打印的相对顺序保持不变。但是在 jax.pmap 和涉及并行性的其他 JAX 转换中使用 ordered=True 会引发错误,因为在并行执行中无法保证顺序。

异步回调

根据后端不同,jax.debug.print 可能会异步执行,即不在主程序线程中。这意味着值可能在您的 JAX 函数返回值后才被打印到屏幕上。

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
# <do something else>
# Prints: x: 2. 

要阻塞函数中的 jax.debug.print,您可以调用 jax.effects_barrier(),它会等待函数中任何剩余的副作用也完成:

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.
# <do something else> 
性能影响
不必要的实现

虽然 jax.debug.print 设计为性能影响最小,但它可能会干扰编译器优化,并且可能会影响 JAX 程序的内存配置文件。

def f(w, b, x):
  logits = w.dot(x) + b
  jax.debug.print("logits: {}", logits)
  return jax.nn.relu(logits) 

在这个例子中,我们在线性层和激活函数之间打印中间值。像 XLA 这样的编译器可以执行融合优化,可以避免在内存中实现 logits。但是当我们在 logits 上使用 jax.debug.print 时,我们强制这些中间值被实现,可能会减慢程序速度并增加内存使用。

此外,当使用 jax.debug.printjax.pjit 时,会发生全局同步,将值实现在单个设备上。

回调开销

jax.debug.print 本质上会在加速器和其主机之间进行通信。底层机制因后端而异(例如 GPU vs TPU),但在所有情况下,我们需要将打印的值从设备复制到主机。在 CPU 情况下,此开销较小。

此外,当使用 jax.debug.printjax.pjit 时,会发生全局同步,增加了一些额外开销。

jax.debug.print 的优势和限制

优势
  • 打印调试简单直观
  • jax.debug.callback 可用于其他无害的副作用
限制
  • 添加打印语句是一个手动过程
  • 可能会对性能产生影响

使用 jax.debug.breakpoint() 进行交互式检查

TL;DR 使用 jax.debug.breakpoint() 暂停执行您的 JAX 程序以检查值:

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution! 


jax.debug.breakpoint() 实际上只是 jax.debug.callback(...) 的一种应用,用于捕获调用堆栈信息。因此它与 jax.debug.print 具有相同的转换行为(例如,对 jax.debug.breakpoint() 进行 vmap-ing 会将其展开到映射的轴上)。

用法

在编译的 JAX 函数中调用 jax.debug.breakpoint() 会在命中断点时暂停程序。您将看到一个类似 pdb 的提示符,允许您检查调用堆栈中的值。与 pdb 不同的是,您不能逐步执行程序,但可以恢复执行。

调试器命令:

  • help - 打印出可用的命令
  • p - 评估表达式并打印其结果
  • pp - 评估表达式并漂亮地打印其结果
  • u(p) - 上移一个堆栈帧
  • d(own) - 下移一个堆栈帧
  • w(here)/bt - 打印出回溯
  • l(ist) - 打印出代码上下文
  • c(ont(inue)) - 恢复程序的执行
  • q(uit)/exit - 退出程序(在 TPU 上不起作用)

示例

jax.lax.cond 结合使用

当与 jax.lax.cond 结合使用时,调试器可以成为检测 naninf 的有用工具。

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z
f(2., 0.) # ==> Pauses during execution! 

锐利的特性

因为 jax.debug.breakpoint 只是 jax.debug.callback 的一种应用,所以它与 jax.debug.print 一样具有锐利的特性,但也有一些额外的注意事项:

  • jax.debug.breakpointjax.debug.print 更多地实现了中间值,因为它强制实现了调用堆栈中的所有值。
  • jax.debug.breakpoint 的运行时开销比 jax.debug.print 更大,因为它可能需要将 JAX 程序中的所有中间值从设备复制到主机。

jax.debug.breakpoint() 的优势和限制

优势
  • 简单、直观且(在某种程度上)标准
  • 可以同时检查多个值,上下跟踪调用堆栈。
限制
  • 可能需要使用多个断点来准确定位错误的源头
  • 会产生许多中间值


JAX 中文文档(三)(4)https://developer.aliyun.com/article/1559705

相关文章
|
4月前
|
机器学习/深度学习 PyTorch API
JAX 中文文档(六)(1)
JAX 中文文档(六)
36 0
JAX 中文文档(六)(1)
|
4月前
|
机器学习/深度学习 异构计算 Python
JAX 中文文档(四)(3)
JAX 中文文档(四)
29 0
|
4月前
|
存储 机器学习/深度学习 编译器
JAX 中文文档(九)(1)
JAX 中文文档(九)
50 0
|
4月前
|
存储 缓存 索引
JAX 中文文档(五)(3)
JAX 中文文档(五)
56 0
|
4月前
|
编译器 异构计算 Python
JAX 中文文档(四)(2)
JAX 中文文档(四)
31 0
|
4月前
|
编译器 API 异构计算
JAX 中文文档(一)(2)
JAX 中文文档(一)
60 0
|
4月前
|
并行计算 Linux 异构计算
JAX 中文文档(一)(1)
JAX 中文文档(一)
115 0
|
4月前
|
API Python
JAX 中文文档(八)(3)
JAX 中文文档(八)
33 0
|
4月前
|
存储 并行计算 开发工具
JAX 中文文档(十)(1)
JAX 中文文档(十)
46 0
|
4月前
|
Python
JAX 中文文档(十)(5)
JAX 中文文档(十)
27 0