嵌入表分片与哈希管理:支撑万亿参数的关键技术
1、 Hash管理及实现
如下图所示:在推荐系统中,大部分ID的原始特征都是离散型(global ids),因为其取值空间巨大且稀疏(如用户ID或物品ID可能达到百万甚至亿级别),直接作为输入会导致嵌入表维度爆炸,引发存储和计算瓶颈。常见的做法是需要将离散的ID转为连续的行号(hash indices),从而压缩特征空间,实现高效存储和查询。

Hash管理的一过程由feature map统一管理存储位置,避免直接使用原始ID导致的维度爆炸问题。在CPU侧,hashmap和FastHashmap用于DDR特征表的多级缓存管理:标准hashmap通用性强,而FastHashmap针对大特征表优化,通过减少冲突和预计算提升查询速度,但以更高内存消耗为代价。Hash表管理根据复杂度分为两种模式:
Hash表管理的两种模式
在NPU+Torch+TorchRec框架中,稀疏全局ID(global_ids)到嵌入向量(embeddings)的转换遵循global ids -> indices -> embeddings的三步流程。由于NPU对硬件哈希算子支持相对较弱,相比GPU的专用哈希处理单元,需要额外增加ids->indices的软件层转换。同时,基于hybrid_tirchrec的单层hash管理模式,提供了实现的基本功能,而随着数据量的增大,基于embcache_torchrec的双层模式通过缓存机制进一步优化大规模数据集下稀疏特征处理的性能。
单层模式
从HashEmbeddingCollection开始,通过HybridHashTable直接进行ID到嵌入向量的转换。使用IdsMapper进行ID映射和去重,将原始稀疏 ID(如用户 ID)转换为 embedding 表内部索引,然后直接生成嵌入向量。具体流程图如下图所示:

单层模式下测试用例:torchrec/hybrid_torchrec/test/test_ids_process.py · Ascend/RecSDK - 码云 - 开源中国
双层模式
在双层架构下,Hash 表的管理引入了缓存管理相关组件,以提升整体性能与一致性。整个流程中,输入数据首先通过 IdsMapper 将全局 ID 映射为嵌入表(Embedding Table)的行号,并完成去重操作。随后,通过动态管理接口 EmbeddingUpdate 的 InsertOrAssign 操作,将更新后的嵌入值插入或覆盖至主机内存中的嵌入表。
这些更新会进一步同步到 EmbTableFastHashMap,确保后续对全局 ID 的查找能够获取到最新的嵌入向量,保障数据一致性。在 Hash 表的内部实现中,以 EmbTable 作为基类,其子类 EmbTableFastHashMap 专为支持批量多 key 操作而设计。该子类通过循环调用底层 FastHashtable(用于处理单 key 操作)来提升批量处理效率,同时集成了优化器状态生成功能,便于在训练过程中维护动量、梯度等辅助信息。

2. 分桶
目前版本侧采用取余分桶策略确定每条 Embedding 的存储位置,核心逻辑是根据 Embedding 的 ID 取余后的余数,将其分配到对应桶中,最后将这些桶均分到指定的设备上。这种策略的显著优势在于适配大型 Embedding 表,能把数据均匀切分到不同设备,从根本上避免单设备负载集中的问题。
其主要操作流程如下:首先对物品索引(indices) 进行哈希处理(如取模运算 indices % blockSize),再依据哈希值将 idx 分配到N个桶(默认设置为 256 个);接着按设备数量(即 rank 数)对这些桶进行均分,确保每个 rank 分配到的 ID 数据相对连续,从而减少通信开销并提高缓存命中率。该分桶机制嵌入训练的前向流程中,在训练的每个 step,都会对当前 batch 内的数据执行分桶操作。
如下图为分布式训练中以 world_size=2 和桶数N=2 为例,数据分桶的详细流程:初始输入包括查表的indices张量(如[0, 3, 2, 7])和查表的偏移长度lengths张量(如[1, 1, 1, 1]),通过取模运算(indices % bucket)将数据分桶到bucket0和bucket1(例如索引0和2分配到bucket0,索引3和7分配到bucket1)。分桶后生成newIndices张量(如[0, 2, 3, 7])和newLengths张量(如[1, 1, 0, 0, 0, 0, 1, 1]),同时在多卡训练时,会对没有分配到当前卡的位置进行补0处理;当启用序列模式时,unbucketizePermute张量(如[0, 2, 1, 3])用于恢复原始索引顺序。这种设计通过补零和仅处理有效索引,避免了传统偏移方法的标量计算开销,提升了多卡训练效率。

分桶源码位于:torchrec/hybrid_torchrec/src/ids_process/bucketize.cpp · Ascend/RecSDK - 码云 - 开源中国
2、 嵌入表分片机制(支持大EmbTable)
分片(sharding)机制是针对大规模数据(如亿级实体、TB 级数据)的优化方案,通过将数据分割为多个部分,以提升处理效率和模型性能。如下图所示:常见分片类型包括表分片(Table-Wise)、行分片(Row-Wise)、列(Column-Wise)和DP。结合多级缓存策略可进一步优化存储资源(如平衡 NPU 显存 HBM 与 CPU 内存 DDR 的负载)。

以下内容展示了基于“ebc(Embedding Bag Collection)”模块在分布式环境下的分片配置信息。实验环境包含两个计算设备(rank:0 和 rank:1),EmbeddingBagCollection 中共有两个嵌入表:product_table 和 user_table,均采用fused融合计算内核——该内核专为大规模稀疏嵌入场景设计。本示例在不同分片策略配置下展示了不同的切片效果, 具体实现如下:
首先引入相关依赖,只需导入torchrec,torch.distributed作为dist以及其他必要的模块,以支持分布式环境下的EmbeddingBagCollection分片配置。
import torch
import torchrec
import os
from torchrec import JaggedTensor, KeyedJaggedTensor
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
import torch.distributed as dist
在分布式环境下定义和初始化一个嵌入表集合,并定义了两个嵌入表配置。
rank = int(os.environ["LOCAL_RANK"])
torch.npu.set_device(rank)
device = torch.device("npu:{}".format(rank))
dist.init_process_group(backend="hccl")
# 定义 embedding的table
ebc = torchrec.EmbeddingBagCollection(
device="meta",
tables=[
torchrec.EmbeddingBagConfig(
name="product_table",
embedding_dim=128, num_embeddings=4096,
feature_names=["product"], pooling=torchrec.PoolingType.MEAN,
),
torchrec.EmbeddingBagConfig(
name="user_table",
embedding_dim=128, num_embeddings=4096,
feature_names=["user"], pooling=torchrec.PoolingType.MEAN,
)
]
)
在分布式环境中,根据不同的分片策略进行sharding_types的配置。
# 分片策略
# row_wise分片策略
constrans = {"product_table": ParameterConstraints(sharding_types=["row_wise"]), "user_table": ParameterConstraints(sharding_types=["row_wise"])}
# column_wise分片策略
# constrans = {"product_table": ParameterConstraints(sharding_types=["column_wise"]), "user_table": ParameterConstraints(sharding_types=["column_wise"])}
# table_wise分片策略
# constrans = {"product_table": ParameterConstraints(sharding_types=["table_wise"]), "user_table": ParameterConstraints(sharding_types=["table_wise"])}
#分片计划
planner = EmbeddingShardingPlanner(
topology=Topology(world_size=2,compute_device="npu",),
constraints=constrans,)
2.1、 Row-Wise
行分片(Row-Wise Sharding)是一种关键的分布式训练策略,用于解决大规模稀疏特征下 Embedding 表显存占用过高的问题。该策略将 Embedding 表按 ID 行进行水平切分,每个 GPU 仅维护部分行的参数,从而实现参数空间的有效拆分。结合 AlltoAll 输入分发与分布式梯度更新机制,TorchRec 能够在保持高吞吐的同时支持十亿级 ID 的端到端训练。该方案特别适用于具有海量离散特征的推荐模型,是构建可扩展推荐系统的基础能力之一。
"ebc"模块(通常指Embedding Bag Collection,用于处理大规模嵌入表)在行分片策略(row-wise sharding)下,分片计划会通过planner.collective_plan生成一个分布式执行方案,具体代码案例如下图所示:
# 分片策略
constrans = {
"product_table": ParameterConstraints(sharding_types=["row_wise"]),
"user_table": ParameterConstraints(sharding_types=["row_wise"])}
#模型分片计划
plan = planner.collective_plan(
module=ebc, #待分片模型
sharders=[EmbeddingBagCollectionSharder()], #分片器
pg=dist.GroupMember.WORLD) #设备群组
# print(plan)
如图所示,该输出展示了“ebc”模块在行分片策略下的具体分片配置。两个嵌入表均采用了 row_wise分片类型和 fused融合计算内核——该内核专为大规模稀疏嵌入场景设计,通过算子融合和内存优化显著提升计算效率。
每个嵌入表被均匀划分为两个部分,分别部署于 rank0 和 rank1 设备上,每个分片的大小为 [2048, 128]。
● product_table 分片详情
- rank0上的分片:偏移[0,0],大小[2048,128],负责ID 0到2047。
- rank1上的分片:偏移[2048,0],大小[2048,128],负责ID 2048到4095。
对于user_table:分片方式相同,也是按行均匀划分,每个分片2048行,128列。

2.2、 Table-Wise:纯HBM支持
表分片(Table-Wise)是一种以整个嵌入表为单位进行拆分的策略,将原始表按特定规则划分为多个独立的子表,每个子表可视为一个桶(Bucket)。这种策略尤其适合与访问频率相关的协同优化方案:例如,可先将原始表拆分为“热门物品表”(高频访问)和“冷门物品表”(低频访问),再对高频表进一步采用 Row-Wise 等方法进行更细粒度的划分,从而实现存储与计算效率的平衡。这种策略的优势在于实现简单,每个设备独立承载整张嵌入表,尤其适用于不同表之间特征规模差异大或访问模式有明显区别的场景。
"ebc"模块在表分片策略(table-wise sharding)下,分片计划会通过planner.collective_plan生成一个分布式执行方案,具体代码案例如下图所示:
# 分片策略 table_wise
constrans = {"product_table": ParameterConstraints(sharding_types=["table_wise"]),
"user_table": ParameterConstraints(sharding_types=["table_wise"])}
#模型分片计划
plan = planner.collective_plan(
module=ebc, #待分片模型
sharders=[EmbeddingBagCollectionSharder()], #分片器
pg=dist.GroupMember.WORLD) #设备群组
print(plan)
如下图所示,在该配置下,两个嵌入表被分别完整地部署于不同设备,实现了表级别的并行:
● product_table 的分片偏移(shard offsets)为 [0, 0],分片大小(shard sizes)为 [4096, 128],完整部署于 rank:0/npu:0;
● user_table 的分片偏移为 [0, 0],分片大小为 [4096, 128],完整部署于 rank:1/npu:1。

2.3、 Column-Wise(原理,昇腾暂不支持)
列分片(Column-Wise)是指将嵌入表按列(即特征维度)进行切分的策略。该方式适用于特征数量较多但每个特征的嵌入维度相对较小的场景,通过将不同特征组分布到不同设备,可实现特征级别的模型并行。
然而,由于列切分对设备之间的通信要求较高,目前该策略在昇腾硬件平台上暂未获得支持。在主流推荐系统场景中,特征维度通常远小于特征数量,因此 Column-Wise 分片在实际应用中较为少见,更常见的并行方式仍以 Row-Wise 和 Table-Wise 为主。
"ebc"模块在列分片策略(column-wise sharding)下,分片计划会通过planner.collective_plan生成一个分布式执行方案,具体代码案例如下图所示:
# 分片策略 column_wise
constrans = {"product_table": ParameterConstraints(sharding_types=["column_wise"]),
"user_table": ParameterConstraints(sharding_types=["column_wise"])}
#模型分片计划
plan = planner.collective_plan(
module=ebc, #待分片模型
sharders=[EmbeddingBagCollectionSharder()], #分片器
pg=dist.GroupMember.WORLD) #设备群组
print(plan)
该图片展示了在分布式环境下 product_table和 user_table的分片配置信息。从图中可见,两个表均采用了 column_wise(列分片)的分片策略和 fused融合计算内核,并分别完整部署于两个不同计算设备上。
以 product_table为例,其分片配置如下:
● 该表被整体分配并部署于 rank 0(npu:0)设备上;
● shard offsets 为 [0, 0],表示分片从原嵌入表的第0行、第0列开始;
● shard sizes 为 [4096, 128],代表该分片尺寸为 4096 行 × 128 列,即整张表完整存储于当前设备,负责全部 ID 范围(0 到 4095)的嵌入向量。
同样地,user_table以完全相同的方式整体部署于 rank 1(npu:1) 设备上,分片偏移为 [0, 0],分片大小同样为 [4096, 128],承担整张表的嵌入查询任务。

2.4、 DP(Data-parallel)
数据并行模式下,每个 NPU 设备存储全量 Embedding 数据。在反向传播过程中,需通过 AllReduce 操作聚合所有 NPU 上的参数梯度,确保各设备模型参数一致。该模式适用于小 Embedding 表(内存占用低),因无需数据拆分,可简化通信逻辑,但受限于 HBM 容量,不适合大规模数据场景。
该图片展示了在分布式环境下 product_table和 user_table的在DP情况下的配置信息,由图中可以看出,每个卡上都存储着两个表所有的数据。

总结:
在分布式深度学习训练中,特别是处理大规模嵌入表(如推荐系统场景),合理的分片策略是优化存储、负载和通信开销的核心手段。针对不同规模和数据特性的嵌入表,可参考一下原则:
● 大型表优先用 Row-Wise + 分桶 + 多级缓存,平衡负载与存储;
● 热门 / 冷门表拆分结合 Table-Wise 与 Row-Wise,优化高频数据访问;
● 小表适用 DP 模式,简化通信并保证参数一致性。