结合 TorchEasyRec 代码中的实际使用,从四个核心维度分析。
一、KeyedJaggedTensor(KJT)— 数据层加速
问题:推荐系统的特征天然是变长的
以 dbmtl_taobao.config 为例,16 个特征中有多值 ID 特征(如 brand、cate_id),不同样本的特征长度不同。
标准 PyTorch 做法(nn.EmbeddingBag)
# 每个特征独立处理,需要逐个调用 emb_user_id = self.emb_user_id(user_id_tensor) # 调用 1 emb_brand = self.emb_brand(brand_values, offsets) # 调用 2 emb_cate_id = self.emb_cate_id(cate_values, offsets) # 调用 3 # ... 16 个特征 = 16 次独立的 Embedding 查找 # 最后手动 torch.cat([emb_user_id, emb_brand, ...])
- 16 次独立的 CUDA kernel launch,每次 kernel launch 有固定开销(~10μs)
- 变长特征需要手动管理 padding 或 offsets
- CPU→GPU 数据传输也是 16 次独立 tensor
TorchRec 的 KJT + EmbeddingBagCollection 做法
# 所有特征打包成一个 KJT kjt = KeyedJaggedTensor( keys=["user_id", "brand", "cate_id", ...], # 16 个特征名 values=torch.tensor([...]), # 所有 ID 拼接 lengths=torch.tensor([...]), # 每个样本每个特征的长度 ) # 一次调用完成所有 Embedding 查找 result = ebc(kjt) # 单次 fused kernel
从 embedding.py 可以看到:
self.ebc = EmbeddingBagCollection(list(emb_bag_configs.values()), device=device)
加速对比
维度 |
标准 PyTorch |
TorchRec KJT + EBC |
Kernel launch |
N 次(N 个特征) |
1 次(fused) |
数据传输 |
N 个 tensor 独立传输 |
1 个 KJT 批量传输 |
内存布局 |
分散,cache 不友好 |
连续紧凑,cache 友好 |
变长处理 |
手动 padding/offsets |
内置 lengths 支持 |
二、分布式 Embedding 分片(Sharding)— 突破单卡显存瓶颈
推荐模型的 Embedding 表极大,以 config 中为例:
user_id: 1,141,730 × 16 = ~17MB brand: 461,498 × 16 = ~7MB adgroup_id: 846,812 × 16 = ~13MB
这只是 demo 数据。生产环境 Embedding 表可达数十 GB 甚至 TB 级别,远超单卡显存。
TorchRec 的自动分片
从 plan_util.py 可以看到 TorchEasyRec 使用 EmbeddingShardingPlanner 自动规划:
planner = EmbeddingShardingPlanner( topology=topology, # 集群拓扑(GPU 数量、显存、带宽) enumerator=EmbeddingEnumerator(...), # 枚举所有可能的分片方案 proposer=[DynamicProgrammingProposer(), UniformProposer()], # DP 求最优 ) plan = planner.collective_plan(model, sharders, ...) model = DistributedModelParallel(model, plan=plan, ...)
分片策略
- table_wise:整张表放在一个 GPU 上
- row_wise:大表按行切分到多个 GPU
- column_wise:按 embeddingdim 切分
- data_parallel:每个 GPU 全量复制
Planner 用动态规划算法自动找到最优分片方案,使得:
- 显存均衡分布在各卡上
- 通信量最小化
- 各卡计算负载均衡
三、TrainPipelineSparseDist — 流水线并行加速
这是 TorchRec 的杀手级特性。从 dist_util.py 可以看到 TorchEasyRec 直接使用了这个流水线。
无流水线(标准 PyTorch DDP)
Batch 1: [数据加载] → [Embedding 分发] → [前向] → [反向] → [梯度同步] Batch 2: → [数据加载] → ...
每个阶段串行等待,GPU 大量空闲。
TrainPipelineSparseDist(3 阶段流水线)
时间 → T1 T2 T3 T4 T5 Batch1: [数据加载] [Emb分发] [前向+反向] Batch2: [数据加载] [Emb分发] [前向+反向] Batch3: [数据加载] [Emb分发] [前向+反向]
- Batch N 的 Embedding 分发 与 Batch N-1 的前向/反向 同时进行
- Batch N+1 的数据加载 与 Batch N 的 Embedding 分发 同时进行
- 通信和计算完全重叠,GPU 利用率接近 100%
四、fbgemm-gpu Fused Optimizer — 优化器层加速
从 optimizer.py 可以看到:
from fbgemm_gpu import split_table_batched_embeddings_ops_training
标准 PyTorch 的 Adagrad/Adam 优化器对 Embedding 的更新流程:
前向 → 反向得到梯度 → 优化器读取梯度 → 更新权重(3 次显存读写)
fbgemm-gpu 的 fused optimizer 将梯度计算和权重更新融合在一个 kernel 中:
前向 → 反向直接在 kernel 内更新权重(1 次显存读写)
结合 apply_optimizer_in_backward,实现了 backward 和 optimizer.step 的融合,避免梯度的中间存储。
总结:端到端加速对比
优化层 |
加速来源 |
典型提升 |
数据表示 |
KJT 紧凑存储 + fused kernel |
减少 kernel launch 开销 |
Embedding 查找 |
EBC 批量 fused 查找 |
10x+ vs 逐特征查找 |
分布式分片 |
自动 DP 最优分片 |
突破单卡显存,线性扩展 |
流水线 |
通信/计算重叠 |
30-50% 吞吐提升 |
优化器 |
fbgemm fused backward |
减少 2/3 显存带宽 |
这就是为什么 TorchEasyRec 的核心数据流全部围绕 TorchRec 构建——它不只是一个 Embedding 库,而是一套从数据表示到分布式训练到优化器的完整加速栈。