利用PyTorch的三元组损失Hard Triplet Loss进行嵌入模型微调

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时计算 Flink 版,5000CU*H 3个月
简介: 本文介绍了如何使用 PyTorch 和三元组边缘损失(Triplet Margin Loss)微调嵌入模型,详细讲解了实现细节和代码示例。

本文介绍如何使用 PyTorch 和三元组边缘损失 (Triplet Margin Loss) 微调嵌入模型,并重点阐述实现细节和代码示例。三元组损失是一种对比损失函数,通过缩小锚点与正例间的距离,同时扩大锚点与负例间的距离来优化模型。

数据集准备与处理

一般的嵌入模型都会使用Sentence Transformer ,其中的

encode()

方法可以直接处理文本输入。但是为了进行微调,我们需要采用 Transformer 库,所以就要将文本转换为模型可接受的 token IDs 和 attention masks。Token IDs 代表模型词汇表中的词或字符,attention masks 用于防止模型关注填充 tokens。

本文使用

thenlper/gte-base

模型,需要对应的 tokenizer 对文本进行预处理。该模型基于

BertModel

架构:

 BertModel(
   (embeddings): BertEmbeddings(
     (word_embeddings): Embedding(30522, 768, padding_idx=0)
     (position_embeddings): Embedding(512, 768)
     (token_type_embeddings): Embedding(2, 768)
     (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
     (dropout): Dropout(p=0.1, inplace=False)
   )
   (encoder): BertEncoder(
     (layer): ModuleList(
       (0-11): 12xBertLayer(
         (attention): BertAttention(
           (self): BertSdpaSelfAttention(
             (query): Linear(in_features=768, out_features=768, bias=True)
             (key): Linear(in_features=768, out_features=768, bias=True)
             (value): Linear(in_features=768, out_features=768, bias=True)
             (dropout): Dropout(p=0.1, inplace=False)
           )
           (output): BertSelfOutput(
             (dense): Linear(in_features=768, out_features=768, bias=True)
             (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
             (dropout): Dropout(p=0.1, inplace=False)
           )
         )
         (intermediate): BertIntermediate(
           (dense): Linear(in_features=768, out_features=3072, bias=True)
           (intermediate_act_fn): GELUActivation()
         )
         (output): BertOutput(
           (dense): Linear(in_features=3072, out_features=768, bias=True)
           (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
           (dropout): Dropout(p=0.1, inplace=False)
         )
       )
     )
   )
   (pooler): BertPooler(
     (dense): Linear(in_features=768, out_features=768, bias=True)
     (activation): Tanh()
   )
 )

利用 Transformers 库的

AutoTokenizer

AutoModel

可以简化模型加载过程,无需手动处理底层架构和配置细节。

 fromtransformersimportAutoTokenizer, AutoModel  
 fromtqdmimporttqdm  
 tokenizer=AutoTokenizer.from_pretrained("thenlper/gte-base")  

 # 获取文本并进行标记  
 train_texts= [df_train.loc[i]['content'] foriinrange(df_train.shape[0])]  
 dev_texts= [df_dev.loc[i]['content'] foriinrange(df_dev.shape[0])]  
 test_texts= [df_test.loc[i]['content'] foriinrange(df_test.shape[0])]  

 train_tokens= []  
 train_attention_masks= []  
 dev_tokens= []  
 dev_attention_masks= []  
 test_tokens= []  
 test_attention_masks= []  

 forsentintqdm(train_texts):  
   encoding=tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')  
   train_tokens.append(encoding['input_ids'].squeeze(0))  
   train_attention_masks.append(encoding['attention_mask'].squeeze(0))  

 forsentintqdm(dev_texts):  
   encoding=tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')  
   dev_tokens.append(encoding['input_ids'].squeeze(0))  
   dev_attention_masks.append(encoding['attention_mask'].squeeze(0))  

 forsentintqdm(test_texts):  
   encoding=tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')  
   test_tokens.append(encoding['input_ids'].squeeze(0))  
   test_attention_masks.append(encoding['attention_mask'].squeeze(0))

获取 token IDs 和 attention masks 后,需要将其存储并创建一个自定义的 PyTorch 数据集。

 importrandom  
 fromcollectionsimportdefaultdict  
 importtorch  
 fromtorch.utils.dataimportDataset, DataLoader, Sampler, SequentialSampler  

 classCustomTripletDataset(Dataset):  
     def__init__(self, tokens, attention_masks, labels):  
         self.tokens=tokens  
         self.attention_masks=attention_masks  
         self.labels=torch.Tensor(labels)  
         self.label_dict=defaultdict(list)  

         foriinrange(len(tokens)):  
             self.label_dict[int(self.labels[i])].append(i)  
         self.unique_classes=list(self.label_dict.keys())  

     def__len__(self):  
         returnlen(self.tokens)  

     def__getitem__(self, index):  
         ids=self.tokens[index].to(device)  
         ams=self.attention_masks[index].to(device)  
         y=self.labels[index].to(device)  
         returnids, ams, y

由于采用三元组损失,需要从数据集中采样正例和负例。

label_dict

字典用于存储每个类别及其对应的数据索引,方便随机采样。DataLoader 用于加载数据集:

 train_loader=DataLoader(train_dataset, batch_sampler=train_batch_sampler)

其中

train_batch_sampler

是自定义的批次采样器:

 classCustomBatchSampler(SequentialSampler):  
     def__init__(self, dataset, batch_size):  
         self.dataset=dataset  
         self.batch_size=batch_size  
         self.unique_classes=sorted(dataset.unique_classes)  
         self.label_dict=dataset.label_dict  
         self.num_batches=len(self.dataset) //self.batch_size  
         self.class_size=self.batch_size//4  

     def__iter__(self):  
         total_samples_used=0  
         weights=np.repeat(1, len(self.unique_classes))  

         whiletotal_samples_used<len(self.dataset):  
             batch= []  
             classes= []  
             for_inrange(4):  
                 next_selected_class=self._select_class(weights)  
                 whilenext_selected_classinclasses:  
                   next_selected_class=self._select_class(weights)  
                 weights[next_selected_class] +=1  
                 classes.append(next_selected_class)  
                 new_choices=self.label_dict[next_selected_class]  
                 remaining_samples=list(np.random.choice(new_choices, min(self.class_size, len(new_choices)), replace=False))  
                 batch.extend(remaining_samples)  

             total_samples_used+=len(batch)  

             yieldbatch  

     def_select_class(self, weights):  
         dist=1/weights  
         dist=dist/np.sum(dist)  
         selected=int(np.random.choice(self.unique_classes, p=dist))  
         returnselected  

     def__len__(self):  
         returnself.num_batches

自定义批次采样器控制训练批次的构成,本文的实现确保每个批次包含 4 个类别,每个类别包含 8 个数据点。验证采样器则确保验证集批次在不同 epoch 间保持一致。

模型构建

嵌入模型通常基于 Transformer 架构,输出每个 token 的嵌入。为了获得句子嵌入,需要对 token 嵌入进行汇总。常用的方法包括 CLS 池化和平均池化。本文使用的

gte-base

模型采用平均池化,需要从模型输出中提取 token 嵌入并计算平均值。

 importtorch.nn.functionalasF  
 importtorch.nnasnn  

 classEmbeddingModel(nn.Module):  
     def__init__(self, base_model):  
         super().__init__()  
         self.base_model=base_model  

     defaverage_pool(self, last_hidden_states, attention_mask):  
         # 平均 token 嵌入  
         last_hidden=last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)  
         returnlast_hidden.sum(dim=1) /attention_mask.sum(dim=1)[..., None]  

     defforward(self, input_ids, attention_mask):  
         outputs=self.base_model(input_ids=input_ids, attention_mask=attention_mask)  
         last_hidden_state=outputs.last_hidden_state  
         pooled_output=self.average_pool(last_hidden_state, attention_mask)  
         normalized_output=F.normalize(pooled_output, p=2, dim=1)  
         returnnormalized_output  

 base_model=AutoModel.from_pretrained("thenlper/gte-base")  
 model=EmbeddingModel(base_model)
EmbeddingModel

类封装了 Hugging Face 模型,并实现了平均池化和嵌入归一化。

模型训练

训练循环中需要动态计算每个锚点的最难正例和最难负例。

 importnumpyasnp  

 deftrain(model, train_loader, criterion, optimizer, scheduler):  
     model.train()  
     epoch_train_losses= []  

     foridx, (ids, attention_masks, labels) inenumerate(train_loader):  
         optimizer.zero_grad()  

         embeddings=model(ids, attention_masks)  

         distance_matrix=torch.cdist(embeddings, embeddings, p=2) # 创建方形距离矩阵  

         anchors= []  
         positives= []  
         negatives= []  

         foriinrange(len(labels)):  

             anchor_label=labels[i].item()  
             anchor_distance=distance_matrix[i] # 锚点与所有其他点之间的距离  

             # 最难的正例(同一类别中最远的)  
             hardest_positive_idx= (labels==anchor_label).nonzero(as_tuple=True)[0] # 所有同类索引  
             hardest_positive_idx=hardest_positive_idx[hardest_positive_idx!=i] # 排除自己的标签  
             hardest_positive=hardest_positive_idx[anchor_distance[hardest_positive_idx].argmax()] # 最远同类的标签  

             # 最难的负例(不同类别中最近的)  
             hardest_negative_idx= (labels!=anchor_label).nonzero(as_tuple=True)[0] # 所有不同类索引  
             hardest_negative=hardest_negative_idx[anchor_distance[hardest_negative_idx].argmin()] # 最近不同类的标签  

             # 加载选择的  
             anchors.append(embeddings[i])  
             positives.append(embeddings[hardest_positive])  
             negatives.append(embeddings[hardest_negative])  

         # 将列表转换为张量  
         anchors=torch.stack(anchors)  
         positives=torch.stack(positives)  
         negatives=torch.stack(negatives)  

         # 计算损失  
         loss=criterion(anchors, positives, negatives)  
         epoch_train_losses.append(loss.item())  

         # 反向传播和优化  
         loss.backward()  
         optimizer.step()  

         # 更新学习率  
         scheduler.step()  

     returnnp.mean(epoch_train_losses)

训练过程中使用

torch.cdist()

计算嵌入间的距离矩阵,并根据距离选择最难正例和最难负例。PyTorch 的

TripletMarginLoss

用于计算损失。

结论与讨论

实践表明,Batch Hard Triplet Loss 在某些情况下并非最优选择。例如,当正例样本内部差异较大时,强制其嵌入相似可能适得其反。

本文的重点在于 PyTorch 中自定义批次采样和动态距离计算的实现。

对于某些任务,直接在分类任务上微调嵌入模型可能比使用三元组损失更有效。

https://avoid.overfit.cn/post/4b8a8e91f3274f8ca41bfff2a2d60abe

目录
相关文章
|
5天前
|
存储 人工智能 弹性计算
阿里云弹性计算_加速计算专场精华概览 | 2024云栖大会回顾
2024年9月19-21日,2024云栖大会在杭州云栖小镇举行,阿里云智能集团资深技术专家、异构计算产品技术负责人王超等多位产品、技术专家,共同带来了题为《AI Infra的前沿技术与应用实践》的专场session。本次专场重点介绍了阿里云AI Infra 产品架构与技术能力,及用户如何使用阿里云灵骏产品进行AI大模型开发、训练和应用。围绕当下大模型训练和推理的技术难点,专家们分享了如何在阿里云上实现稳定、高效、经济的大模型训练,并通过多个客户案例展示了云上大模型训练的显著优势。
|
8天前
|
存储 人工智能 调度
阿里云吴结生:高性能计算持续创新,响应数据+AI时代的多元化负载需求
在数字化转型的大潮中,每家公司都在积极探索如何利用数据驱动业务增长,而AI技术的快速发展更是加速了这一进程。
|
5天前
|
人工智能 运维 双11
2024阿里云双十一云资源购买指南(纯客观,无广)
2024年双十一,阿里云推出多项重磅优惠,特别针对新迁入云的企业和初创公司提供丰厚补贴。其中,36元一年的轻量应用服务器、1.95元/小时的16核60GB A10卡以及1元购域名等产品尤为值得关注。这些产品不仅价格亲民,还提供了丰富的功能和服务,非常适合个人开发者、学生及中小企业快速上手和部署应用。
|
14天前
|
人工智能 弹性计算 文字识别
基于阿里云文档智能和RAG快速构建企业"第二大脑"
在数字化转型的背景下,企业面临海量文档管理的挑战。传统的文档管理方式效率低下,难以满足业务需求。阿里云推出的文档智能(Document Mind)与检索增强生成(RAG)技术,通过自动化解析和智能检索,极大地提升了文档管理的效率和信息利用的价值。本文介绍了如何利用阿里云的解决方案,快速构建企业专属的“第二大脑”,助力企业在竞争中占据优势。
|
15天前
|
自然语言处理 数据可视化 前端开发
从数据提取到管理:合合信息的智能文档处理全方位解析【合合信息智能文档处理百宝箱】
合合信息的智能文档处理“百宝箱”涵盖文档解析、向量化模型、测评工具等,解决了复杂文档解析、大模型问答幻觉、文档解析效果评估、知识库搭建、多语言文档翻译等问题。通过可视化解析工具 TextIn ParseX、向量化模型 acge-embedding 和文档解析测评工具 markdown_tester,百宝箱提升了文档处理的效率和精确度,适用于多种文档格式和语言环境,助力企业实现高效的信息管理和业务支持。
3936 2
从数据提取到管理:合合信息的智能文档处理全方位解析【合合信息智能文档处理百宝箱】
|
5天前
|
算法 安全 网络安全
阿里云SSL证书双11精选,WoSign SSL国产证书优惠
2024阿里云11.11金秋云创季活动火热进行中,活动月期间(2024年11月01日至11月30日)通过折扣、叠加优惠券等多种方式,阿里云WoSign SSL证书实现优惠价格新低,DV SSL证书220元/年起,助力中小企业轻松实现HTTPS加密,保障数据传输安全。
505 3
阿里云SSL证书双11精选,WoSign SSL国产证书优惠
|
11天前
|
安全 数据建模 网络安全
2024阿里云双11,WoSign SSL证书优惠券使用攻略
2024阿里云“11.11金秋云创季”活动主会场,阿里云用户通过完成个人或企业实名认证,可以领取不同额度的满减优惠券,叠加折扣优惠。用户购买WoSign SSL证书,如何叠加才能更加优惠呢?
986 3
|
9天前
|
机器学习/深度学习 存储 人工智能
白话文讲解大模型| Attention is all you need
本文档旨在详细阐述当前主流的大模型技术架构如Transformer架构。我们将从技术概述、架构介绍到具体模型实现等多个角度进行讲解。通过本文档,我们期望为读者提供一个全面的理解,帮助大家掌握大模型的工作原理,增强与客户沟通的技术基础。本文档适合对大模型感兴趣的人员阅读。
415 17
白话文讲解大模型| Attention is all you need
|
4天前
|
数据采集 人工智能 API
Qwen2.5-Coder深夜开源炸场,Prompt编程的时代来了!
通义千问团队开源「强大」、「多样」、「实用」的 Qwen2.5-Coder 全系列,致力于持续推动 Open Code LLMs 的发展。
|
9天前
|
算法 数据建模 网络安全
阿里云SSL证书2024双11优惠,WoSign DV证书220元/年起
2024阿里云11.11金秋云创季火热进行中,活动月期间(2024年11月01日至11月30日),阿里云SSL证书限时优惠,部分证书产品新老同享75折起;通过优惠折扣、叠加满减优惠券等多种方式,阿里云WoSign SSL证书将实现优惠价格新低,DV SSL证书220元/年起。
561 5