为什么使用 TorchRec 训练和推理更快

简介: 本文结合TorchEasyRec实践,从四大维度解析推荐系统加速:1)KeyedJaggedTensor统一变长特征,实现Embedding批量融合查找;2)自动分布式分片突破单卡显存瓶颈;3)TrainPipelineSparseDist流水线并行,重叠通信与计算;4)fbgemm-gpu融合优化器,减少显存访问。端到端提升训练效率与扩展性。

结合 TorchEasyRec 代码中的实际使用,从四个核心维度分析。

一、KeyedJaggedTensor(KJT)— 数据层加速

问题:推荐系统的特征天然是变长的

以 dbmtl_taobao.config 为例,16 个特征中有多值 ID 特征(如 brand、cate_id),不同样本的特征长度不同。

标准 PyTorch 做法(nn.EmbeddingBag)

# 每个特征独立处理,需要逐个调用
emb_user_id = self.emb_user_id(user_id_tensor)      # 调用 1
emb_brand   = self.emb_brand(brand_values, offsets)   # 调用 2
emb_cate_id = self.emb_cate_id(cate_values, offsets)  # 调用 3
# ... 16 个特征 = 16 次独立的 Embedding 查找
# 最后手动 torch.cat([emb_user_id, emb_brand, ...])
  • 16 次独立的 CUDA kernel launch,每次 kernel launch 有固定开销(~10μs)
  • 变长特征需要手动管理 padding 或 offsets
  • CPU→GPU 数据传输也是 16 次独立 tensor

TorchRec 的 KJT + EmbeddingBagCollection 做法

# 所有特征打包成一个 KJT
kjt = KeyedJaggedTensor(
    keys=["user_id", "brand", "cate_id", ...],  # 16 个特征名
    values=torch.tensor([...]),                    # 所有 ID 拼接
    lengths=torch.tensor([...]),                   # 每个样本每个特征的长度
)
# 一次调用完成所有 Embedding 查找
result = ebc(kjt)  # 单次 fused kernel

embedding.py 可以看到:

self.ebc = EmbeddingBagCollection(list(emb_bag_configs.values()), device=device)

加速对比

维度

标准 PyTorch

TorchRec KJT + EBC

Kernel launch

N 次(N 个特征)

1 次(fused)

数据传输

N 个 tensor 独立传输

1 个 KJT 批量传输

内存布局

分散,cache 不友好

连续紧凑,cache 友好

变长处理

手动 padding/offsets

内置 lengths 支持


二、分布式 Embedding 分片(Sharding)— 突破单卡显存瓶颈

推荐模型的 Embedding 表极大,以 config 中为例:

user_id:     1,141,730 × 16 = ~17MB
brand:         461,498 × 16 = ~7MB
adgroup_id:    846,812 × 16 = ~13MB

这只是 demo 数据。生产环境 Embedding 表可达数十 GB 甚至 TB 级别,远超单卡显存。

TorchRec 的自动分片

plan_util.py 可以看到 TorchEasyRec 使用 EmbeddingShardingPlanner 自动规划:

planner = EmbeddingShardingPlanner(
    topology=topology,                    # 集群拓扑(GPU 数量、显存、带宽)
    enumerator=EmbeddingEnumerator(...),  # 枚举所有可能的分片方案
    proposer=[DynamicProgrammingProposer(), UniformProposer()],  # DP 求最优
)
plan = planner.collective_plan(model, sharders, ...)
model = DistributedModelParallel(model, plan=plan, ...)

分片策略

  • table_wise:整张表放在一个 GPU 上
  • row_wise:大表按行切分到多个 GPU
  • column_wise:按 embeddingdim 切分
  • data_parallel:每个 GPU 全量复制

Planner 用动态规划算法自动找到最优分片方案,使得:

  • 显存均衡分布在各卡上
  • 通信量最小化
  • 各卡计算负载均衡

三、TrainPipelineSparseDist — 流水线并行加速

这是 TorchRec 的杀手级特性。从 dist_util.py 可以看到 TorchEasyRec 直接使用了这个流水线。

无流水线(标准 PyTorch DDP)

Batch 1: [数据加载] → [Embedding 分发] → [前向] → [反向] → [梯度同步]
Batch 2:                                                    → [数据加载] → ...

每个阶段串行等待,GPU 大量空闲。

TrainPipelineSparseDist(3 阶段流水线)

时间 →  T1        T2        T3        T4        T5
Batch1: [数据加载] [Emb分发]  [前向+反向]
Batch2:           [数据加载] [Emb分发]  [前向+反向]
Batch3:                     [数据加载] [Emb分发]  [前向+反向]
  • Batch N 的 Embedding 分发 与 Batch N-1 的前向/反向 同时进行
  • Batch N+1 的数据加载 与 Batch N 的 Embedding 分发 同时进行
  • 通信和计算完全重叠,GPU 利用率接近 100%

四、fbgemm-gpu Fused Optimizer — 优化器层加速

optimizer.py 可以看到:

from fbgemm_gpu import split_table_batched_embeddings_ops_training

标准 PyTorch 的 Adagrad/Adam 优化器对 Embedding 的更新流程:

前向 → 反向得到梯度 → 优化器读取梯度 → 更新权重(3 次显存读写)

fbgemm-gpu 的 fused optimizer 将梯度计算和权重更新融合在一个 kernel 中:

前向 → 反向直接在 kernel 内更新权重(1 次显存读写)

结合 apply_optimizer_in_backward,实现了 backward 和 optimizer.step 的融合,避免梯度的中间存储。


总结:端到端加速对比


优化层

加速来源

典型提升

数据表示

KJT 紧凑存储 + fused kernel

减少 kernel launch 开销

Embedding 查找

EBC 批量 fused 查找

10x+ vs 逐特征查找

分布式分片

自动 DP 最优分片

突破单卡显存,线性扩展

流水线

通信/计算重叠

30-50% 吞吐提升

优化器

fbgemm fused backward

减少 2/3 显存带宽

这就是为什么 TorchEasyRec 的核心数据流全部围绕 TorchRec 构建——它不只是一个 Embedding 库,而是一套从数据表示到分布式训练到优化器的完整加速栈。

相关文章
|
5天前
|
人工智能 JSON 监控
Claude Code 源码泄露:一份价值亿元的 AI 工程公开课
我以为顶级 AI 产品的护城河是模型。读完这 51.2 万行泄露的源码,我发现自己错了。
4030 11
|
15天前
|
人工智能 JSON 机器人
让龙虾成为你的“公众号分身” | 阿里云服务器玩Openclaw
本文带你零成本玩转OpenClaw:学生认证白嫖6个月阿里云服务器,手把手配置飞书机器人、接入免费/高性价比AI模型(NVIDIA/通义),并打造微信公众号“全自动分身”——实时抓热榜、AI选题拆解、一键发布草稿,5分钟完成热点→文章全流程!
11616 135
让龙虾成为你的“公众号分身” | 阿里云服务器玩Openclaw
|
4天前
|
人工智能 数据可视化 安全
王炸组合!阿里云 OpenClaw X 飞书 CLI,开启 Agent 基建狂潮!(附带免费使用6个月服务器)
本文详解如何用阿里云Lighthouse一键部署OpenClaw,结合飞书CLI等工具,让AI真正“动手”——自动群发、生成科研日报、整理知识库。核心理念:未来软件应为AI而生,CLI即AI的“手脚”,实现高效、安全、可控的智能自动化。
1415 7
王炸组合!阿里云 OpenClaw X 飞书 CLI,开启 Agent 基建狂潮!(附带免费使用6个月服务器)
|
6天前
|
人工智能 自然语言处理 数据挖掘
零基础30分钟搞定 Claude Code,这一步90%的人直接跳过了
本文直击Claude Code使用痛点,提供零基础30分钟上手指南:强调必须配置“工作上下文”(about-me.md+anti-ai-style.md)、采用Cowork/Code模式、建立标准文件结构、用提问式提示词驱动AI理解→规划→执行。附可复制模板与真实项目启动法,助你将Claude从聊天工具升级为高效执行系统。
|
5天前
|
人工智能 定位技术
Claude Code源码泄露:8大隐藏功能曝光
2026年3月,Anthropic因配置失误致Claude Code超51万行源码泄露,意外促成“被动开源”。代码中藏有8大未发布功能,揭示其向“超级智能体”演进的完整蓝图,引发AI编程领域震动。(239字)
2307 9