准入淘汰策略详解
准入淘汰策略主要应用于推荐系统的多级缓存架构中,是推荐模型训练过程中缓存管理的核心组成部分。它位于训练流程的缓存管理层,负责动态管理嵌入表(embedding table)中特征的存储和生命周期。
在大规模的稀疏场景下的模型训练,部分特征频次较低,无法为模型的训练提供有效的信息,同时会造成内存的浪费。也存在部分长时间不进行更新的特征,由于其时效性低,干扰训练结果。因此我们针对业务场景提出合适的准入淘汰机制,对动态场景下的词表进行精细化管理。
推荐系统的多级缓存架构中,准入(Admission)和淘汰(Eviction)功能是确保缓存高效、精准工作的核心机制。它们共同决定了“什么样的数据值得进入缓存”以及“当缓存满了,什么数据应该被请出去”。在多级缓存架构中,准入与淘汰策略的主要目标如下:
a. 最大化缓存命中率:让尽可能多的请求命中缓存,减少直达底层数据库或推荐计算引擎的压力。
b. 最小化访问延迟:确保用户能快速拿到推荐结果,提升体验。
c. 最优使用缓存资源:内存是昂贵且有限的,必须存放最有价值的数据
目前我们的准入淘汰功能在CPU-Ps中进行控制,前向阶段进行准入计数和准入判断,反向阶段进行淘汰计数,在save阶段或固定step阶段执行淘汰。
准入淘汰机制的代码流程主要位于:torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/feature_filter.cpp · Ascend/RecSDK - 码云 - 开源中国
1、 准入策略
在推荐系统和机器学习中,特征准入机制是管理稀疏特征(如用户ID或物品ID)的关键技术,用于控制哪些特征键(key)被纳入模型训练或推理,以优化内存使用、计算效率和模型性能。默认采用NoneFilters,即所有key均准入。如下所示为目前支持三种准入策略:
基于访问次数: key出现次数>=阈值,则该key被准入。这类似于低频过滤,用于淘汰不常见或噪声特征,减少存储开销。
基于概率: 生成随机数 < 阈值([0,1]之间)则key被准入。这样避免模型偏差于高频特征,提升多样性。
基于show-click: 结合用户行为数据(如展示次数show_cnt和点击次数click_cnt),计算ShowClick score 。
ShowClick_score = alpha * show_cnt + beta * click_cnt(alpha,beta为自定义参数)。
如果得分超过阈值,则key被准入。这种机制适用于CTR预估或广告推荐场景,优先保留有高交互概率的特征,从而提升模型相关性。
| 准入策略 | 准入计数 | 准入判断 |
|---|---|---|
| 基于访问次数 | 前向阶段进行计数更新 | 前向查表阶段进行准入判断 |
| 基于概率 | 无需更新计数 | 前向查表阶段进行准入判断;若key已存在,不会重复判断。 |
| 基于show-click | 前向阶段进行ShowClick计数更新 | 前向查表阶段进行准入判断(ShowClick score > 阈值),则准入; |
这些机制可单独或组合使用(如先基于访问次数过滤,再通过概率机制探索),并根据业务需求调整阈值。整体上,准入策略旨在平衡特征覆盖率与计算效率,尤其在处理高基数稀疏数据时至关重要。
准入score的更新(UpdateScore)和判断逻辑放在反向中执行,根据当前batch使用的所有id对应的show/click值(作为map存储在 ShowClickFilter 中)对其score进行更新。其中反向过程从cpu-ps查询出来的emb,需要分下面三种情况:
1) 直接从cpu-ps查询到,说明该id已经准入并且未被淘汰,此时进行score更新之后返回emb
2) 未能从cpu-ps查询到,则依据准入score进行判断,符合条件则准入(创建Value值之后给该Value的score做更新) 然后创建emb的值,cbu-ps插入该id并返回emb
3) 没能在cpu-ps查询到并且判断也不能准入,返回的是nullptr,此时在代码上游判断value是否为null来决定emb是否是默认值
准入机制统一方案:
准入状态记录:在全局唯一性(global unique)处理完成后,调用封装好的计数接口,准确记录每个key的准入状态。
首先定义与准入相关的配置参数,
// 准入相关配置 int32_t admitThreshold // 准入阈值,默认值表示未开启准入 std::unordered_map<int64_t, FeatureRecord> featureRecordMap; // 准入,记录key次数 not_admitted_default_value // 未准入ids对应embedding值调用封装好的计数接口 StatisticsKeyCount ,准确记录每个key的准入状态。
""" const at::Tensor& batchKeys, // 当前批次所有特征键的连续内存张量 const torch::Tensor& offset, // 每个key的偏移量 const at::Tensor& batchKeyCounts // 当前批次涉及哪些嵌入表的索引 int64_t tableIndex """ void EmbcacheManager::StatisticsKeyCount(const at::Tensor& batchKeys, const torch::Tensor& offset, const at::Tensor& batchKeyCounts, int64_t tableIndex) { // 未开启local unique时,counts为空tensor,处理时默认key对应count为1 bool isCountDataEmpty = batchKeyCounts.numel() == 0; auto* featureDataPtr = batchKeys.data_ptr<int64_t>(); // 获取原始数据指针 auto* countDataPtr = batchKeyCounts.data_ptr<int64_t>(); auto* offsetDataPtr = offset.data_ptr<int64_t>(); int64_t start = offsetDataPtr[tableIndex]; int64_t end = offsetDataPtr[tableIndex + 1]; TORCH_CHECK(end <= batchKeys.numel()) featureFilters[tableIndex].StatisticsKeyCount(featureDataPtr, countDataPtr, start, end, isCountDataEmpty); }定义计数接口的内部实现,通过判断目前的key是否已经存在,进行计数。
""" const int64_t* featureDataPtr, // 特征数据指针 const int64_t* countDataPtr, // 计数数据指针 int64_t startIndex, // 起始索引 int64_t endIndex, // 结束索引 bool isCountDataEmpty // 计数数据模式标志 """ void FeatureFilter::StatisticsKeyCount(const int64_t* featureDataPtr, const int64_t* countDataPtr, int64_t startIndex, int64_t endIndex, bool isCountDataEmpty) { for (int64_t i = startIndex; i < endIndex; ++i) { auto feature = *(featureDataPtr + i); // 顺序内存访问 auto count = isCountDataEmpty ? 1 : *(countDataPtr + i); auto iter = featureRecordMap.find(feature); if (iter != featureRecordMap.end()) { // 特征已存在,更新现有计数 iter->second.count += count; } else { // 新特征,创建记录 FeatureRecord featureRecord = {count}; featureRecordMap[feature] = featureRecord; } } }准入条件判断:在
ComputeSwapInfo阶段,基于准入计数接口的数值动态评估key是否符合准入条件,确保逻辑一致性。class FeatureFilter{} # 特征过滤和准入控制 AdmitAndEvictConfig : # 准入和淘汰参数配置 """ admit_threshold: 准入阈值控制 evict_threshold: 基于时间的淘汰策略 evict_step_interval: 特征过滤功能 "" if (embConfigs[idx].admitAndEvictConfig.IsAdmitEnabled()) { featureFilters[idx].CountFilter(keyPtr, offsetPerKey[i], offsetPerKey[i + 1]); }CountFilter函数,对特征进行准入判断的内部实现
void FeatureFilter::CountFilter(int64_t* featureDataPtr, int64_t startIndex, int64_t endIndex) { // 准入检查,将未准入的特征置为-1 auto thresholdCount = static_cast<uint64_t>(admitThreshold); for (int64_t i = startIndex; i < endIndex; ++i) { auto feature = *(featureDataPtr + i); auto iter = featureRecordMap.find(feature); if (iter != featureRecordMap.end() && iter->second.count < thresholdCount) { *(featureDataPtr + i) = INVALID_KEY; } } }- CPU侧查表优化:在CPU侧的Embedding查表操作中引入不计数开关,避免CPU侧查询干扰准入与淘汰的计数规则,保证计数准确性。
- NPU侧数据恢复:在NPU执行
lookup操作后,根据准入结果将未通过准入的key自动恢复为默认值,维持数据完整性。
准入计数的计数策略
主要注意点是适配local unique。在local unique之后,ids的数量需要进行all2all通信,才能准确统计到所有卡上的ids数量。
local unique时,会统计ids出现的次数。并通过自定义KJT: KeyedJaggedTensorWithCount 实现count的all2all。
- 对ids进行计数
- count处理(开启local unique时需要)
- input_dist阶段,调用 bucketize_kjt_before_all2all 做本地去重时会返回counts;
- 自定义kjt: KeyedJaggedTensorWithCount,对counts数据进行all2all;
- 将count记录到C++ map中
- post_input_dist阶段,do_unique_hash_out:
- 从 kjt中counts(开local unique),或者构造 empty tensor(ids count=1);
- 调用ids_mapper -> cache_mgr.statistics_key_count() 记录
- post_input_dist阶段,do_unique_hash_out:
- count处理(开启local unique时需要)
- 对未准入的ids处理
- compute_and_output_dist阶段:
- embeddings = lookup()
- _reset_embedding_for_not_admitted_ids(embeddings, xx)
将ids对应index为0的位置的embedding重置为not_admitted_default_value
- compute_and_output_dist阶段:
2、 淘汰策略
在推荐系统的稀疏特征表管理中,淘汰策略是优化内存使用和模型性能的核心机制之一。当系统资源受限或需要维护特征表的高效性时,合理的淘汰策略能够自动移除低价值特征,确保高频重要特征得到保留。默认采用NoneShrink,即所有key均不淘汰。如下所示为目前支持的五种淘汰策略:
基于固定step: 训练过程中,key在连续多个训练步长(step)内未参与训练(即未被访问或更新),则触发淘汰。
基于访问次数: 每个key维护一个version变量,表明特征的新鲜程度。 若version >= threshold,且代表这个key很久没有用 过,则予以淘汰。
基于L2范数: 计算所有key的L2范数(即各维度平方和的平方根)L2_score,若L2_score < threshold(反映特征权重整体较小),则认为该特征对模型贡献微弱,触发淘汰。L2范数计算公式如下图所示:

基于时间和频次: 时间time:key被尝试淘汰的次数;频次freq:key参与训练的总次数;
在出现频次最低的k个key中,如果这个key属于尝试淘汰频率最高的k个key,则淘汰这个key;
基于show-click: 针对广告或推荐场景,key的评分由展示次数(show_cnt)和点击次数(click_cnt)加权计算:
ShowClick score = alpha * show_cnt + beta * click_cnt(alpha,beta为自定义参数)。
在得分最低的k个key中,进一步筛选点击率(click-through rate)最低的key进行淘汰。
| 淘汰策略 | 淘汰计数 | 淘汰判断 |
|---|---|---|
| 基于固定step | 反向阶段,将key对应的global_step_version参数更新到最新的global_step | save阶段或者固定step之后,调用doshrink,判断key是否达到淘汰标准,然后执行淘汰,删除被淘汰key的所有记录 |
| 基于访问次数 | 反向阶段将训练的key的version更新为0;淘汰执行阶段,所有key的version += 1,即尝试淘汰的次数 | |
| 基于L2范数 | 无需更新计数 | |
| 基于时间和频次 | 反向阶段更新time和freq,将涉及到key的time=0,freq+=1;淘汰执行阶段 time+=1; | |
| 基于show-click | 淘汰执行阶段,所有key的version进行衰减:version *= decay_rate,每一次的淘汰都会触发一次衰减。 |
更新showClick score参数,调用doshrink,获取所有需要淘汰的key,执行淘汰,删除被淘汰key的所有记录 |
对于淘汰而言: 淘汰的触发时机是在Saver.save()时调用,通过表的大小是否大于safebucket_size判断是否需要淘汰。
Embedding.shrink时,触发淘汰的逻辑是先计算当前表中有多少已经准入的 kv。因为show/click策略,无论是否准入都会将kv插入表中,如果不准入,那么表中的 value 的 param 为null,因此这里需要进行一次遍历,得到表中实际准入的kv个数,std:floor(n*(1 - gamma))来计算得到需要淘汰的个数,之后按照分数进行排序淘汰。
NPU上的准入和淘汰的实现跟GPU相比有差异,NPU上的准入调用时机是每次查表对key进行计数(频率)确认是否进入且信息的更新在前向过程触发。淘汰的机制是在模型save阶段或固定step阶段,检测key的淘汰状态。对于需淘汰的key,立即从embcache_mgr中清除其记录,确保淘汰及时生效。
准入机制统一方案:
首先,定义与淘汰机制相关的配置参数,
EvictFeatureRecord evictFeatureRecord; 淘汰记录管理
// 淘汰相关配置
uint64_t evictThreshold // unit: second
uint64_t evictStepInterval // 淘汰间隔步数
uint64_t recordTsBatchId
std::time_t latestTimestamp // 当前表最新的时间戳,用于判断淘汰
std::unordered_map<int64_t, std::time_t> timestampRecordMap; // 淘汰,记录key时间戳
准入key的淘汰信息更新:在input_dist之前,调用RecordTimestamp,记录准入判断后,key进入embedding表的时间戳;
""" const at::Tensor& batchKeys, // 批次特征key值 const std::vector<int64_t>& offsetPerKey, // 每个表的偏移量 const at::Tensor& timestamps, // 时间戳数据 const std::vector<int32_t>& tableIndices // 表索引列表 """ void EmbcacheManager::RecordTimestamp(const at::Tensor& batchKeys, const std::vector<int64_t>& offsetPerKey, const at::Tensor& timestamps, const std::vector<int32_t>& tableIndices) { TimeCost recordTimestampTC; const auto* keyPtr = batchKeys.data_ptr<int64_t>(); const auto* timestampsPtr = timestamps.data_ptr<int64_t>(); const std::vector<int32_t>& curTableIndices = tableIndices.empty() ? embTableIndies_ : tableIndices; for (int64_t i = 0; i < embNum; ++i) { int32_t idx = curTableIndices[i]; if (embConfigs[idx].admitAndEvictConfig.IsEvictEnabled()) { # 记录访问时间,并对时间戳进行判断是否对数据进行淘汰 featureFilters[idx].RecordTimestamp(keyPtr, offsetPerKey[i], offsetPerKey[i + 1], timestampsPtr); } } }基于时间戳的内部淘汰机制,代码具体实现:
// 时间戳记录机制 void FeatureFilter::RecordTimestamp(const int64_t* featureDataPtr, int64_t startIndex, int64_t endIndex, const int64_t* timestampDataPtr) { // 定期淘汰执行机制 // 因记录timestamp和计算swap info存在步数差异,因此记录timestamp时需同时记录淘汰keys if (recordTsBatchId > 0 && (recordTsBatchId + 1) % evictStepInterval == 0) { FeatureEvict(); // 执行特征淘汰 } } void FeatureFilter::FeatureEvict() //淘汰 { std::vector<int64_t>& evictKeys = evictFeatureRecord.GetEvictKeys(); auto tempEvictThreshold = static_cast<std::time_t>(evictThreshold); for (auto iter : timestampRecordMap) { auto feature = iter.first; if (feature == -1) { continue; } bool needEvict = false; if (latestTimestamp - iter.second > tempEvictThreshold) { evictKeys.emplace_back(feature); needEvict = true; } } // 淘汰掉的key从timestampRecordMap中移出 bool isAdmitEnabled = admitThreshold != -1; for (auto feature : evictKeys) { timestampRecordMap.erase(feature); if (isAdmitEnabled) { // 开启准入时同时移出准入map中的key featureRecordMap.erase(feature); } } }淘汰触发:触发淘汰之后,需要从embcache_mgr中删除;
void EmbcacheManager::EvictFeatures() { TimeCost evictFeaturesTC; size_t evictKeyCount = 0; for (int32_t i = 0; i < embNum; ++i) { if (!embConfigs[i].admitAndEvictConfig.IsEvictEnabled()) { continue; } // 获取当前表要淘汰的keys const std::vector<int64_t>& evictFeatures = featureFilters[i].evictFeatureRecord.GetEvictKeys(); // 调用swapManager删除映射信息 // 删除embeddingTables中的embedding待对应step的swap out emb update执行完成后触发 swapManagers[i].RemoveKeys(evictFeatures); featureFilters[i].evictFeatureRecord.SetSwapCount(swapCount); evictKeyCount += evictFeatures.size(); } }
基于时间的淘汰机制代码流程
- input_dist阶段:记录timestamp数据
EmbCacheShardedEmbeddingCollection -> input_dist() -> FeatureFilter::RecordTimestamp() - 通过EmbcacheManager::EvictFeatures()实现淘汰
- 1 通过FeatureFilter::FeatureEvict获取淘汰的key
- 2 删除换入换出映射:swapManagers[i].RemoveKeys
- 3 记录swapCount;
- 4 当HostEmbeddingUpdate执行到 swapCount 步数时,删除table embedding;
3、 准入淘汰测试用例
RecSDK提供一个测试脚本,用于测试分布式环境下的嵌入缓存(EmbCache)的准入(admit)和逐出(evict)机制。
torchrec/torchrec_embcache/tests/acc_test/test_feature_filter.py · Ascend/RecSDK - 码云 - 开源中国
全部配置和基础类定义,其中enable_admit和enable_evict分别控制准入和淘汰机制的启用。
import os
import torch
import torch.distributed as dist
from torchrec_embcache.distributed.configs import EmbCacheEmbeddingConfig, AdmitAndEvictConfig
from typing import List
from dataclasses import dataclass
# 全局配置
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "2"))
LOOP_TIMES = 500
EVICT_STEP_INTERVAL = LOOP_TIMES // 4
BATCH_NUM = LOOP_TIMES
_SAVE_PATH = "save_dir"
@dataclass
class ExecuteConfig:
world_size: int
table_num: int
embedding_dims: List[int]
num_embeddings: List[int]
sharding_type: str
lookup_len: int
device: str
enable_admit: bool
enable_evict: bool
为准入淘汰功能配置参数,准入淘汰配置设置。
def setup_embedding_configs(config):
admit_threshold = 2 if config.enable_admit else default_config.admit_threshold
evict_threshold = 2000_0000 if config.enable_evict else default_config.evict_threshold
embedding_configs = []
for i in range(config.table_num):
admit_evict_config = AdmitAndEvictConfig(
admit_threshold=admit_threshold, # 准入阈值
evict_threshold=evict_threshold, # 淘汰阈值
evict_step_interval=EVICT_STEP_INTERVAL # 淘汰执行间隔
)
# 创建嵌入表配置
ec_config = EmbCacheEmbeddingConfig(admit_and_evict_config=admit_evict_config)
embedding_configs.append(ec_config)
return embedding_configs
淘汰机制核心实现,基于时间戳识别不活跃的嵌入项并重置其状态。
class TestModel:
def _record_timestamp_info_cpu(self, batch, table_num, batch_id):
"""记录每个嵌入ID的时间戳信息"""
# 记录每个嵌入ID的最后访问时间
for table_index in range(table_num):
for ids, ts in zip(values_per_table, ts_per_table):
self.timestamps_for_table[table_index][ids] = ts
self.last_timestamp_for_table[table_index] = max(
self.last_timestamp_for_table[table_index], ts
)
def _evict_embedding_cpu(self, evict_threshold, embeddings, opt, batch_id):
"""执行淘汰操作"""
for table_index in range(table_num):
evict_ids = []
last_ts = self.last_timestamp_for_table[table_index]
# 找出需要淘汰的ID(长时间未访问的)
for ids, ts in self.timestamps_for_table[table_index].items():
if last_ts - ts > evict_threshold: # 超过淘汰阈值
evict_ids.append(ids)
# 重置被淘汰的嵌入和优化器状态
for ids in evict_ids:
embeddings[table_name].weight[ids].data.copy_(init_emb)
slot_tensor[ids].data.copy_(init_slot)
测试执行流程,根据不同的测试场景(准入/淘汰)执行相应的验证逻辑。
def execute(rank: int, config: ExecuteConfig):
# 设置嵌入配置(包含准入淘汰参数)
embedding_configs = setup_embedding_configs(config)
test_model = TestModel(rank, config.world_size, config.device)
# 仅淘汰场景功能测试
if not config.enable_admit and config.enable_evict:
golden_results = test_model.cpu_golden_loss(embedding_configs, ...)
# 主测试
test_results = test_model.test_loss(embedding_configs, ...)
# 准入计数检查
if config.enable_admit and not config.enable_evict:
_check_admit_key_count(...)
# 结果验证(淘汰场景)
if not config.enable_admit and config.enable_evict:
assert torch.allclose(golden_results, test_results)
准入机制(Admit)的验证功能,用于检查准入计数是否正确。 通过遍历所有批次数据,统计每个嵌入ID的访问次数。
def _check_admit_key_count(data_loader_golden, embedding_configs: List[EmbCacheEmbeddingConfig], rank):
# 1 手动统计key count
iter_ = iter(data_loader_golden)
loop_time = 0
table_key_count = [{} for _ in range(len(embedding_configs))]
while loop_time < LOOP_TIMES:
loop_time += 1
batch: Batch = next(iter_, None)
kjt = batch.sparse_features
values = kjt.values()
offset_per_key = kjt.offset_per_key()
for i in range(len(offset_per_key) - 1):
values_per_table = values[offset_per_key[i]: offset_per_key[i + 1]]
for ids in values_per_table:
ids = ids.item()
if ids in table_key_count[i]:
table_key_count[i][ids] = table_key_count[i][ids] + 1
else:
table_key_count[i][ids] = 1
# 2 读取保存目录下的key count
key_file_saved = os.path.join(_SAVE_PATH, "table{}", "rank{}".format(rank), "key", "slice.data")
count_file_saved = os.path.join(_SAVE_PATH, "table{}", "rank{}".format(rank), "admit_count", "slice.data")
table_key_count_saved = [{} for _ in range(len(embedding_configs))]
for i in range(len(embedding_configs)):
key_data = np.fromfile(key_file_saved.format(i), dtype=np.int64).reshape(-1)
count_data = np.fromfile(count_file_saved.format(i), dtype=np.int64).reshape(-1)
for index in range(key_data.shape[0]):
ids = key_data[index]
count = count_data[index]
table_key_count_saved[i][ids] = count
# 3 对比数据
length_equal = all(len(table_key_count[i]) == len(table_key_count_saved[i]) for i in range(len(embedding_configs)))
4、 准入淘汰计数策略流程图
图片中描述了Embedding查表流程,包括local unique、分桶、AlltoAll、glocal unique等步骤。然后是准入方案,涉及ShowClick计数、准入计数写入、获取换入换出的key等。最后是淘汰方案,包括获取淘汰的key和删除embcache_mgr中淘汰key的信息。
