JAX 中文文档(五)(5)

简介: JAX 中文文档(五)

JAX 中文文档(五)(4)https://developer.aliyun.com/article/1559812

使用 Pallas 编写 TPU 内核

原文:jax.readthedocs.io/en/latest/pallas/tpu/details.html

本页关注试图在 Google TPU 上运行 Pallas 内核时的重要细节。首先,TPU 后端仍处于实验阶段,并且只接受 JAX  NumPy 的子集。此外,为 TPU  编写高性能代码可能需要仔细考虑硬件的本机能力。虽然许多对硬件不自然的模式将被接受,但它们最终可能需要软件模拟,并可能减慢计算速度。

警告

此功能仍应视为实验性功能,因为工作仍在进行中(特别是在改进错误消息方面)。

注意

虽然此处描述的所有功能都是实验性的,但我们仍然非常认真地维护其正确性。因此,在尝试编写 TPU 内核时可能看到“未实现”错误并不罕见。但是,如果编译器接受了内核,它必须返回预期的结果。

如果您看到意外的输出,请将其与传递interpret=Truepallas_call的内核运行进行比较。如果结果不一致,请提交错误报告

什么是 TPU?

TPU 是 Google 开发的硬件加速器。您可以将 TPU 视为专门用于机器学习工作负载的  GPU。因此,它们的架构有相当大的差异。然而,我们相信 Pallas 可以使您轻松开始编写 TPU  内核,即使您没有完全理解底层硬件也是如此。话虽如此,深入了解硬件将确实使编写高性能内核变得更加容易。

简言之,TPU 与 GPU 的主要区别在于 TPU 是顺序机器,具有非常宽的向量寄存器(类似于  CPU!)。与此同时,它们允许软件安排某些操作在后台执行,使其与主指令流异步执行。这包括 HBM 内存访问(无法直接发出,而是必须通过 DMA  子单元预取到较低层次的内存层次结构)、矩阵乘法(由 MXU 单元支持)或矩阵转置和置换(由 XLU 单元支持)。

如果您对详细了解 TPU 架构感兴趣,我们建议阅读多年来发表的一系列论文集。虽然许多论文谈论特定的 TPU 代,但其中许多描述的思想也适用于后续代。

值得注意的属性和限制

BlockSpecs 和网格迭代

在 Pallas 中,BlockSpecs 通常按预期行为——每次核心体调用都会访问输入的片段,并且旨在初始化输出的一个片段。

警告

并非所有的窗口形状都受支持。如果你的输入的最后两个维度分别大于 8 和 128,那么这些维度中的窗口形状必须是对应因子的倍数。如果输入维度较小,则窗口应跨越整个维度。

Pallas TPU 核心的一个有趣方面是它们处理内存空间的方式:虽然pallas_call的输入通常驻留在 HBM(主 TPU 内存)中,但传递到核心体的引用将指向内存层次结构较低的缓冲区(VMEM 或 SMEM)。这使得核心体能够以非常高的速度读写它们,而所有与 HBM 的通信(具有非常高的延迟)由编译器处理并与计算重叠。

此外,与 GPU 相比,TPU 实际上是高度序列化的机器。因此,网格通常不是并行处理的,而是按字典顺序顺序处理(尽管请参阅多核 TPU 配置部分的例外情况)。这解锁了一些有趣的功能:

  • 当两个(按字典顺序)连续的网格索引使用相同输入的片段时,第二次迭代的 HBM 传输将被跳过,因为数据已经可用。
  • 多个核心体调用可以向输出的同一片段写入,而不会有任何竞态条件的风险。但我们确实要求写入特定片段的所有调用是连续的。

关于输出的“连续”限制通常意味着网格维度的某些前缀总是变化,而调用需要访问的输出窗口对于其余后缀保持不变。

例如,在实现矩阵乘法的 Pallas TPU 核心时,通常会使用三维网格:前两个维度对应于沿左操作数的第一轴和第二操作数的第二轴切片。第三和最后网格轴将瓦片化减少维度。与减少维度对应的网格轴必须是最后一个,因为输出窗口沿此轴不变。输出引用随后可用作部分结果的累加器。

注意

对于这样一个低级内存层次结构(16MB+),VMEM  相当大,这使得可以使用较大的窗口大小。通常情况下,窗口大小越大,最终硬件利用率就越好。然而,可能会指定一个窗口大小,该大小(加上保存溢出矢量寄存器所需的空间)超过了  VMEM 的大小。在这种情况下,您可能会看到一个低级编译器错误消息,抱怨内存不足错误。

维度排序是有意义的

在 JAX 程序中,jax.jit内部数组的排序通常不会影响性能,因为编译器可以自由地重新排列它们。但是,由于 Pallas 旨在暴露更低级的功能,维度顺序对生成的代码质量有很大影响。

请记住,TPU 主要在 2D 矢量寄存器上执行大部分计算。Pallas TPU 只会考虑将中间数组的最后两个维度映射到这些矢量寄存器维度(子通道和通道)。形状为(n, 1, 1)的数组保证需要至少n个矢量寄存器来表示。如果n变得太大,则可能会导致溢出,并由于过大的内存占用而导致  VMEM 内存不足错误。但这也可能不会发生 —  低级编译器可以重新排列指令以降低寄存器压力,并且实际上在这方面做得非常好。尽管如此,保持最后两个维度大(特别是最后一个维度),同时使前导维度保持小是一个很好的经验法则。

多核 TPU 配置

在更新的 TPU 生成中,芯片上的两个核心通常被抽象为单个设备。为了利用多个核心,Pallas 必须打破顺序网格执行的保证,并且需要在核心上并行化一个网格轴。这是一个选择加入的过程。为了允许这样做,pallas_call需要一个额外的名为dimension_semantics的参数:

该参数是一个列表,其条目数量与网格中的轴数量相同。只有parallel维度可以在核心上分区。作为一个经验法则,维度是并行的,除非输出窗口不变。因此,dimension_semantics始终是一些parallel轴的数字,后跟一些arbitrary轴的数字。

尽管在 2 核 TPU 设备上分区内核通常会导致 2  倍速度提升,但实际上可能会显著小于此值。特别是如果体的不同实例具有非常不同的成本,这一点尤为真实。如果所有昂贵的步骤都映射到一个核心,而所有廉价的步骤都分配给另一个核心,则第二个核心将在第一个完成其任务之前处于空闲状态。

Pallas TPU 通常偏好将大小为 TPU 核心数量倍数的轴进行分区,并且更喜欢分区主导的网格轴。

将操作数放入 SMEM

大多数 TPU 计算将在向量单元上进行。然而,有许多情况下进行一些标量操作是有用的,例如执行控制流。因此,TPU 配备了一个单独的标量单元,并附有一个单独的标量存储器(SMEM)。按照一个经验法则,用于执行控制流决策的任何数据应放置在 SMEM 中。

SMEM 是一种低延迟内存,支持随机访问,但只能用单个指令读写 32 位值(与 VMEM 事务的 4KBi 粒度相比非常小,但由于没有对齐要求而更加灵活!)。

当实现不按规则模式访问输入块的内核时,标量内存也非常有用,例如编写块稀疏内核时。在 Pallas 中,可以通过将pallas_callgrid参数替换为具有非零num_scalar_prefetch参数的PrefetchScalarGridSpecgrid_spec来实现这一点。如果num_scalar_prefetchn,那么pallas_call的前n个参数将放置在 SMEM 中。对于这些参数,不应指定任何BlockSpec。但是,对于所有后续参数的BlockSpec,不仅会收到网格索引,还会收到领先操作数的 SMEM 引用。

注意

我们正在努力实现此功能的示例。敬请关注!

支持的数据类型

目前,Pallas TPU 仅支持以下数据类型:

  • jnp.float32
  • jnp.bfloat16
  • jnp.int*(所有精度,除了jnp.int4
  • jnp.uint*(所有精度)

计算放置

所有标量(即 0D)数组将存储在标量寄存器中,并在标量核心上执行操作。所有其他操作(甚至是对单个元素但是 1D+数组的操作)将在向量核心上执行。

支持的操作

矩阵乘法

矩阵乘法始终以float32格式生成结果。如果您的输入不是 float32,建议使用lax.dot并将preferred_element_type设置为jnp.float32

当使用lax.dot_general时,可以将矩阵乘法操作数的最后两个维度的转置融合到操作中,这可以提高整体内核性能。

精度控制

Pallas TPU 的降低考虑到了jax.default_matmul_precision。为了获得最佳性能(和最低精度),请使用bfloat16。如果您关心数值精度,可能需要将精度设置为float32

警告

即使将 32 位操作数传递给矩阵乘法,除非请求float32精度,否则它们将会被四舍五入为bfloat16

转置

如果值至少有 4 个维度,则除了最后两个轴以外的任意转置都是免费的。否则,仅实现了最后两个轴的转置。请注意,一些最后两个维度的转置可以融合到矩阵乘法中。

访问内存

可以读取或更新引用的任意片段,受实现约束的限制。目前,对于宽度为 32 位的输入没有限制,但只支持某些更窄类型的切片模式。总是支持最后两个维度中分别是 8 和 128 的倍数的对齐读写。

通常在向量内存的读写发生在形状为 (8, 128) 的瓦片上。因此,当读取或写入至少有两个维度的引用时,最佳性能是在内存访问的基础偏移具有瓦片可整除的索引,并且读取区域的大小是瓦片大小的倍数。

逐元素操作

支持许多逐元素操作。值得注意的是,硬件通常仅支持使用 32 位类型进行逐元素计算。在加载使用较低精度类型的操作数时,通常应先将其升级为 32 位类型再应用逐元素操作。

值得注意的是,它们的成本可能显著不同。因此,我们列出了三类支持的操作:廉价(🟢)、中等(🌕)和昂贵(🔴)。

操作 成本
jnp.add+ 🟢
jnp.sub- 🟢
jnp.mul* 🟢
///% 🌕
jnp.maxjnp.min 🟢
jnp.where(选择) 🟢
jnp.abs 🟢
` ^`,`&`,`~`
<<>> 🟢
比较运算(==,…) 🟢
类型转换(.astype 🟢
jnp.exp 🌕
jnp.tanh 🌕
jnp.pow 🌕
jnp.sin 🔴
jnp.cos 🔴

许多 JAX 函数是基于其他 JAX 原语实现的,因此此列表可能不完整。例如,jax.nn.relu 是基于比较实现的,而 jnp.where 在 Pallas 内核中也能工作。

数组构造函数

所有常数数组构造函数都受支持(jnp.onesjnp.zerosjnp.full)。特别是,截至今天,jax.random 模块与 Pallas 兼容。

归约

支持求和、最大值和最小值的归约,但一次只能在一个数组轴上进行。

对最后一个数组维度的归约通常是最慢的。对倒数第二个维度的归约更快,但仍比前面的维度慢。

广播

广播的性能特性与归约非常相似。总是支持除了最后两个维度之外的所有广播,且是免费的。沿着倒数第二个维度进行广播较慢,而沿着最后一个维度进行广播最慢。

重塑

如常地,所有维度除了最后两个维度的重塑都是支持的且是免费的。

唯一支持的情况是当重塑可以修改数组的最后两个维度时,即(1)某些前导维度展平到倒数第二个维度,或者(2)它添加了刚刚由归约移除的维度。

控制流程

目前,TPU 后端对控制流的支持有限。目前支持的函数有condfori_loopfor_loop。然而,在编译时,循环原语会完全展开,因此请尽量保持循环执行次数合理小。

过度使用控制流可能导致低级代码生成中的显著回归,建议尽量将多个计算密集型操作挤入一个基本块中。

管道化和块规范

原文:jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html

在本指南中,我们将介绍 TPU 中的内存空间工作原理,并展示如何在 Pallas 中编写可以将内存 I/O 与计算重叠的流水线。

#@title Imports
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np 

TPU 及其内存空间

TPU 和其 TensorCore 包括内存空间(用于存放数组的区域)、寄存器(临时存储标量和数组值的地方)和计算单元(用于处理寄存器中的值的计算单元)。下图显示了一个 TPU 的结构,其中 xy 是存储在高带宽存储器(HBM)中的数组:

让我们更详细地讨论这个图表的组成部分:

  • 内存空间:TPU 拥有高带宽内存(HBM),这通常被称为“设备内存”。还有向量内存(VMEM),一个用于存储向量和数组值的缓存,以及标量内存(SMEM),一个设计用于存储标量值的缓存。
  • 寄存器:TensorCore 拥有两种主要类型的寄存器:向量寄存器(VREGs)存储数组值,标量寄存器(SREGs)存储标量值。值可以从相应的缓存(VREG 的 VMEM 和 SREG 的 SMEM)加载到内存中。
  • 计算单元:TensorCore 包括标量单元、向量单元(VPU)和矩阵单元(MXU),用于进行数值计算。计算单元操作位于 SREG 和 VREG 中的值,并将输出值也存储在这些寄存器中。

为了在我们存储在 HBM 中的值 xy 上执行矢量化计算,我们需要:

  1. 将值 xy 复制到 VMEM 中。
  2. 从 VMEM 中加载值到 VREG 中。
  3. 使用 VPU 或 MXU 执行计算,并将输出存储在 VREG 中。
  4. 将输出 VREG 中的值存储到 VMEM 中。
  5. 将 VMEM 中的输出值复制回 HBM。

让我们实现一个 Pallas 函数来完成这些操作!

def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
  # Load x and y from VMEM into VREGs
  x_vregs = x_vmem_ref[:, :]
  y_vregs = y_vmem_ref[:, :]
  # Execute a vectorized add
  z_vregs = x_vregs + y_vregs
  # Store the output values in VREGs back into VMEM
  z_vmem_ref[:, :] = z_vregs
def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
  # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.
  # It will then copy `x` and `y` from HBM into VMEM.
  z = pl.pallas_call(
      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
  # pallas_call will also copy the output from VMEM back into HBM.
  return z
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y) 
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32) 

我们编写了两个函数:add_matrices_kerneladd_matrices

add_matrices_kernel 操作使用在 VMEM 中存在的 Ref。从 VMEM 的 Ref 加载会产生一个存在于 VREG 中的值。VREG 中的值的行为类似于 jax.Array,我们可以在其上使用 jnpjax.lax 操作来产生新的值,这些新值仍然存在于 VREG 中。当我们产生想要返回的值时,我们将它们存储在输出的 VMEM Ref 中。

add_matrices 函数作用于 jax.Array,并返回一个 jax.Array。在函数内部,我们将 xy 传递给 pallas_callpallas_call 负责将 xy 复制到 VMEM 中,并分配内核操作的 VMEM 缓冲区(包括分配 z_vmem_ref,输出的 VMEM 缓冲区)。内核函数运行完成后,pallas_call 还将 z_vmem_ref 中的值复制到 HBM,最终输出一个 jax.Array

使用 VMEM/SMEM 的限制

Pallas 公开了对低级内存空间(如 VMEM 和 SMEM)的访问,但编写利用它们的内核需要考虑一些因素。

  1. 内存容量。VMEM 和 SMEM 都很!v4 TPU 上的 VMEM 只有 16MiB,SMEM 的范围在几十到几百 KiB。如果我们的数组太大,甚至无法完全放入 VMEM 中。举个例子,一个 f32[2048, 2048] 数组就是 16MiB,因此我们上面的核心代码无法处理超过中等大小的数组。
  2. 内存带宽。从 HBM 和 VMEM 复制数据需要很长时间,至少与大多数计算指令相比是如此。上面的 add_matrices 函数很可能在复制 HBM 和 VMEM 之间花费的时间比执行加法本身要多。

考虑到这两个约束条件,我们必须重新思考如何提高 TPU 的性能策略。

引言:流水线

在一个行动中处理内存容量和带宽约束的流水线计算提供了一种方法。我们所说的流水线是什么意思?

目标是:并行复制到/从 HBM 和 VMEM 同时利用我们的计算单元。但在我们的程序中,这种方式相对困难,因为我们在开始进行计算之前先复制了所有的 xy,从而在复制和计算之间创建了依赖关系。

然而,如果我们可以将计算分成几个子计算(例如,当我们将两个矩阵相加时,可以将原始矩阵的“块”相加在一起),我们现在可以将其中一个子计算的复制与另一个计算的执行重叠起来。让我们通过一个简单的例子来演示:

假设我们将数组 xy 分成 x1, x2y1, y2(例如,沿着主轴进行分割,每个输入结果为两个 (256, 512) 的数组)。现在我们可以执行以下流水线计算。

  1. 复制 x1y1 到 VMEM 中。
  2. 开始将 x2y2 复制到 VMEM。
  3. 从 VMEM 加载 x1, y1 到 VREGs 中。
  4. 使用计算单元执行 z1 = x1 + y1
  5. z1 存储到 VMEM 中。
  6. 开始将 z1 从 VMEM 复制回到 HBM。
  7. 等待 x2, y2 被复制到 VMEM。
  8. 从 VMEM 加载 x2, y2 到 VREGs 中。
  9. 使用计算单元执行 z2 = x2 + y2
  10. z2 存储到 VMEM 中。
  11. 等待 z1 被复制到 HBM。
  12. 开始将 z2 从 VMEM 复制回到 HBM。
  13. 等待 z2 被复制到 HBM。

在这里进行计算时,我们总是异步复制某些内容。这意味着复制过程中的一些时间并不会浪费。

决定流水线计算效率的两个最重要的因素是 a) 我们需要执行多少浮点运算(FLOPs)和 b) 我们需要复制多少字节以执行该计算。这两者的比率(FLOPs/内存使用量)称为操作的算术强度,并确定我们的流水线是计算受限还是内存受限。

Pallas 中的流水线

我们如何在 Pallas 中实现像上面那样的管道?这似乎是一系列复杂的异步数据操作和执行内核,手动实现可能会很麻烦。不要担心!Pallas 提供了一个 API 来表达管道,而不需要太多样板文件,即通过gridBlockSpec

grid,又名循环中的内核

看看在上述流水线示例中,我们多次执行相同的逻辑:步骤 3-5 和 8-10 都执行相同的操作,只是在不同的输入上。这个泛化版本是在同一个内核上多次执行循环。pallas_call提供了一个选项来实现这一点。

循环中的迭代次数由pallas_callgrid参数指定。在概念上:

pl.pallas_call(some_kernel, grid=n)(...) 

映射到

for i in range(n):
  # do HBM -> VMEM copies
  some_kernel(...)
  # do VMEM -> HBM copies 

网格可以推广为多维,对应于嵌套循环。例如,

pl.pallas_call(some_kernel, grid=(n, m))(...) 

等价于

for i in range(n):
  for j in range(m):
    # do HBM -> VMEM copies
    some_kernel(...)
    # do VMEM -> HBM copies 

这可以推广到任意整数元组(长度为d的网格将对应于d个嵌套循环)。

BlockSpec,又称如何分块输入

为了自动管道化我们的计算,我们需要向 Pallas 提供的下一部分信息是如何对其进行分块的信息。具体来说,我们需要提供一个映射,将循环的迭代映射到操作哪些输入和输出块BlockSpec正是这两个信息。

首先,我们为我们的输入选择一个block_shape。在上面的流水线示例中,我们有(512, 512)形状的数组,并沿着主维度分成两个(256, 512)形状的数组。在这个管道中,我们的block_shape将是(256, 512)

然后,我们提供一个index_map函数,将迭代空间映射到块。具体来说,在上述管道中,第 1 次迭代我们想选择x1,第 2 次迭代我们想使用x2。可以用以下index_map表达:

def x_index_map(i):
  return (i, 0) 

然后,我们将构建BlockSpec

block_spec = pl.BlockSpec(x_index_map, (256, 512)) 

BlockSpec对于yz与对xBlockSpec将是相同的。

汇总

我们通过gridin_specsout_specs将这些参数提供给pallas_callin_specs对应于位置参数的元组,out_specs对应于输出)。

def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,))(x, y)
add_matrices_pipelined(x, y) 
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32) 

我们只需向原始函数添加了少量代码以添加自动管道,但BlockSpecgrid做了大量的重复工作!

它是如何工作的?好吧,BlockSpec提供足够的信息来开始从 HBM 到 VMEM 预取我们输入的块。例如,如果我们开始grid的第i次迭代,我们可以将i + 1传递给index_map函数,以获取下一次迭代所需的块。然后,我们可以开始这些块的异步复制。类似地,对于输出,我们可以在开始当前迭代的输出复制之前等待上一次迭代的输出复制完成。

参数化管道

在我们的内核中,参数化块形状是常见的。当优化 Pallas 内核的性能时,块大小可能是最重要的参数!它们允许我们控制管道流程(例如,选择较小的块会在我们的流水线循环中增加更多的迭代,每个迭代的工作量较小)。

此外,我们还可以沿第二维(目前仅沿第一维进行拆分)划分输入和输出。让我们编写一个更通用的内核,处理这两个特性。

def add_matrices_pipelined_2d(
    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
  m, n = x.shape
  block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(m // bm, n // bn),
  )(x, y)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y
) 

处理减少

如何使用pallas_call实现类似jnp.sum的功能?具体来说,我们希望在减少维度上进行流水线处理。

以将(8, 512, 512)形状的数组减少到(512, 512)形状为例。

x = jnp.ones((8, 512, 512))
jnp.sum(x, axis=0) 
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32) 

要使用pallas_call实现这一点,我们可以使用大小为(8,)的网格,并在每次迭代i中将x[i]加载到 VMEM 中。然后我们可以将x[i]添加到输出 VMEM 缓冲区中。让我们先天真地实现这一点。

# Warning: this implementation is incorrect!
def naive_sum_kernel(x_ref, o_ref):
  o_ref[...] += x_ref[...]
def naive_sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      naive_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
      out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
      )(x)
naive_sum(x) 
Array([[9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       ...,
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.]], dtype=float32) 

注意我们如何设置BlockSpecs:我们将(512, 512)维度完全加载到 VMEM 中(在这里没有流水线),但在块形状的index_map中每次迭代选择x的第i维度。在块形状中,我们对该维度使用None,这表示我们正在从x中选择一个单维度,我们希望在内核中将其挤压掉。因此,在 VMEM 中,x_ref也是(512, 512)形状。

out_spec使用lambda i: (0, 0)作为其index_map,指示在管道过程中o_ref保持不变。这意味着我们可以通过从中读取并向其写入来更新其值。或者可以吗?实际上有一个问题:o_ref最初是垃圾,这意味着我们将累积到垃圾中。这将导致整体函数输出不正确的值!

因此,每当我们在内核中进行减少操作时,我们需要确保初始化存储减少值的Ref。我们可以通过在迭代 0 时有条件地向out_ref写入值来实现这一点。我们可以利用辅助函数pl.when(一个方便的包装器,围绕jax.lax.condpl.program_id进行操作),查询我们在网格轴上的迭代。

def sum_kernel(x_ref, o_ref):
  @pl.when(pl.program_id(axis=0) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
  o_ref[...] += x_ref[...]
def sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
      out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
      )(x)
sum(x) 
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32) 

sum函数现在输出正确的值!

关于 Pallas 中减少的最后一件事是它们必须在我们网格的最小维度(最右边)中完成(在上面的示例中,我们的网格是  1 维的,因此我们在其最小维度上进行减少)。这是因为 Pallas 生成的管道不会从 HBM 读取输出。一旦将输出值写回到  HBM,就不能重新访问它。因此,您不能在具有任何重新访问的网格维度上进行减少,因此所有减少操作都需要在最右维度上进行。

Megacore 配置的 TPU

一些 TPU 芯片有两个 TensorCores,但对 JAX 用户来说,它们表现为一个设备。这被称为“megacore”。这两个独立的 TensorCores 分别拥有自己的 VMEM、VREGs、SMEM、SREGs 和计算单元,但共享 HBM

从概念上讲,Megacore 中的 TPU 行为类似于非常简单的 GPU,即只有两个线程。我们如何修改我们的内核以同时利用两个 TensorCores?

基本思想是,如果我们在计算中有尴尬地并行的维度,我们可以将这些维度分配到 TensorCores 上。我们可以通过向 pallas_call 提供一个称为 dimension_semantics 的注释来指示哪些维度是可并行化的。

def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,),
      compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))(
        x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y) 
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32) 

dimension_semantics 应该是一个与 grid 长度相同的元组,其中每个条目都是"parallel""arbitrary""parallel" 表示对 Pallas 来说,与该维度对应的 for 循环的迭代可以独立执行,而不会影响程序的正确性。"arbitrary" 表示对 Pallas 来说,在这个网格维度上不能做任何假设,因此不能并行化。

通过指定 dimension_semantics,我们现在可以同时在每个 TensorCore 上执行内核。Pallas 将自动处理网格的分割。

请注意,Megacore 目前仅适用于 TPU v4 和 TPU v5p。在其他平台上提供 dimension_semantics 注释是一个空操作,但指定它将导致只使用一个 TensorCore(即使有多个可用)。

结论

在本指南中,我们讨论了如何使用 pallas_callgridBlockSpec 表达 TPU 管道。我们讨论了如何通过多维网格表达嵌套循环,并在减少开始时初始化累加器的情况下处理归约。我们还学习了如何通过向内核添加注释来处理 Megacore。

读者留给的练习:

  • 尝试实现一个 sum 内核,该内核也可以管道化其他维度
  • 还要将 add 内核和 sum 内核添加到 Megacore 支持中。
    x)
```py
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32) 

sum函数现在输出正确的值!

关于 Pallas 中减少的最后一件事是它们必须在我们网格的最小维度(最右边)中完成(在上面的示例中,我们的网格是  1 维的,因此我们在其最小维度上进行减少)。这是因为 Pallas 生成的管道不会从 HBM 读取输出。一旦将输出值写回到  HBM,就不能重新访问它。因此,您不能在具有任何重新访问的网格维度上进行减少,因此所有减少操作都需要在最右维度上进行。

Megacore 配置的 TPU

一些 TPU 芯片有两个 TensorCores,但对 JAX 用户来说,它们表现为一个设备。这被称为“megacore”。这两个独立的 TensorCores 分别拥有自己的 VMEM、VREGs、SMEM、SREGs 和计算单元,但共享 HBM

[外链图片转存中…(img-pV4vhPcr-1718951137131)]

从概念上讲,Megacore 中的 TPU 行为类似于非常简单的 GPU,即只有两个线程。我们如何修改我们的内核以同时利用两个 TensorCores?

基本思想是,如果我们在计算中有尴尬地并行的维度,我们可以将这些维度分配到 TensorCores 上。我们可以通过向 pallas_call 提供一个称为 dimension_semantics 的注释来指示哪些维度是可并行化的。

def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,),
      compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))(
        x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y) 
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32) 

dimension_semantics 应该是一个与 grid 长度相同的元组,其中每个条目都是"parallel""arbitrary""parallel" 表示对 Pallas 来说,与该维度对应的 for 循环的迭代可以独立执行,而不会影响程序的正确性。"arbitrary" 表示对 Pallas 来说,在这个网格维度上不能做任何假设,因此不能并行化。

通过指定 dimension_semantics,我们现在可以同时在每个 TensorCore 上执行内核。Pallas 将自动处理网格的分割。

请注意,Megacore 目前仅适用于 TPU v4 和 TPU v5p。在其他平台上提供 dimension_semantics 注释是一个空操作,但指定它将导致只使用一个 TensorCore(即使有多个可用)。

结论

在本指南中,我们讨论了如何使用 pallas_callgridBlockSpec 表达 TPU 管道。我们讨论了如何通过多维网格表达嵌套循环,并在减少开始时初始化累加器的情况下处理归约。我们还学习了如何通过向内核添加注释来处理 Megacore。

读者留给的练习:

  • 尝试实现一个 sum 内核,该内核也可以管道化其他维度
  • 还要将 add 内核和 sum 内核添加到 Megacore 支持中。
相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
4月前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
29 1
|
4月前
JAX 中文文档(九)(3)
JAX 中文文档(九)
38 0
|
4月前
|
编译器 API 异构计算
JAX 中文文档(一)(2)
JAX 中文文档(一)
60 0
|
4月前
|
机器学习/深度学习 算法 编译器
JAX 中文文档(二)(3)
JAX 中文文档(二)
44 0
|
4月前
|
并行计算 Linux 异构计算
JAX 中文文档(一)(1)
JAX 中文文档(一)
115 0
|
4月前
|
存储 缓存 索引
JAX 中文文档(五)(3)
JAX 中文文档(五)
56 0
|
4月前
|
存储 缓存 测试技术
JAX 中文文档(三)(5)
JAX 中文文档(三)
51 0
|
4月前
|
存储 机器学习/深度学习 编译器
JAX 中文文档(九)(1)
JAX 中文文档(九)
50 0
|
4月前
|
并行计算 编译器
JAX 中文文档(六)(4)
JAX 中文文档(六)
26 0
|
4月前
|
API 索引 Python
JAX 中文文档(三)(4)
JAX 中文文档(三)
25 0