引言
在当今数字化快速发展的时代,人工智能(AI)技术已经渗透到我们生活的方方面面,从智能手机中的语音助手到自动驾驶车辆的安全系统,背后都离不开深度学习模型的支持。然而,随着这些模型变得越来越庞大和复杂,它们对计算资源的需求也日益增长,这不仅增加了部署成本,还限制了AI应用在资源有限环境下的广泛应用。例如,在移动设备或边缘计算场景中,由于硬件性能和功耗的限制,直接部署大型模型往往不可行。
为了解决这一问题,研究人员提出了两种关键技术——模型蒸馏(Model Distillation)与模型量化(Model Quantization),它们旨在通过不同的方式压缩复杂的深度学习模型,使得小型化后的模型能够在保持较高准确性的前提下,更高效地运行于各种平台上。这两种方法虽然侧重点不同,但都是为了实现同一个目标:让AI更加轻便、节能且易于部署。
本文将以科普的形式向读者介绍模型蒸馏和模型量化的定义、工作原理及其应用场景,帮助大家理解这两项技术如何助力AI模型瘦身。
模型蒸馏
模型蒸馏的概念
模型蒸馏(Model Distillation)是一种模型压缩和知识迁移的技术,旨在将一个大型、复杂且性能优异的教师模型(Teacher Model)中的知识传递给一个较小、计算效率更高的学生模型(Student Model),将复杂且大的模型作为Teacher,Student模型结构较为简单,用Teacher来辅助Student模型的训练,Teacher学习能力强,可以将它学到的知识迁移给学习能力相对弱的Student模型,以此来增强Student模型的泛化能力,复杂笨重但是效果好的Teacher模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的Student小模型。
其核心思想是利用教师模型输出的软标签(soft targets)—— 即概率分布而非硬标签(hard labels),来指导学生模型的训练。通过这种方式,学生模型不仅学习到数据的类别信息,还能够捕捉到类别之间的相似性和关系,从而提升其泛化能力。
该方法的优势在于能够在不显著损失性能的情况下,显著减少模型大小和计算需求,特别适用于资源受限的设备,如移动设备和嵌入式系统。
主要步骤
模型蒸馏通常包括以下几个步骤。
训练教师模型(Teacher Model):首先训练一个性能优异但通常较为庞大的教师模型。教师模型可以是任何高性能的深度学习模型,如深层神经网络、卷积神经网络(CNN)、Transformer等。
生成软标签(Soft Targets):使用训练好的教师模型对训练数据进行预测,获得每个样本的概率分布。这些概率分布作为软标签,包含了类别之间的相对关系信息。
训练学生模型(Student Model):设计一个较小的学生模型,并使用软标签以及硬标签共同训练。训练过程中,通常采用一个损失函数的加权组合,例如,交叉熵损失(用于硬标签)与 Kullback-Leibler 散度损失(用于软标签)。
优化与调整:通过调整温度参数、损失函数权重等超参数,优化学生模型的性能,使其尽可能接近教师模型。
关键技术与方法
软标签与温度参数
传统的训练方法通常使用硬标签,即每个样本对应一个确定的类别标签。而在模型蒸馏中,教师模型输出的是概率分布(软标签),这些概率反映了教师模型对各类别的信心程度。通过引入温度系数(temperature),可以平滑或锐化这个概率分布,从而提供更丰富的梯度信息,帮助学生模型更好地学习。
而对于温度系数,我们可以这么理解,假设有一位老师讲课速度非常快,信息密度很高,学生可能有点难以跟上。这时如果老师放慢速度,简化信息,就会让学生更容易理解。在模型蒸馏中,温度参数起到的就是类似“调节讲课速度”的作用,帮助学生模型(小模型)更好地理解和学习教师模型(大模型)的知识。专业点说就是让模型输出更加平滑的概率分布,方便学生模型捕捉和学习教师模型的输出细节。
数学表达式为:
较高的温度会使得输出分布更加平滑,能够更好地揭示类别之间的相似性,从而提供更丰富的知识给学生模型。训练过程中,通常会同时调整温度参数来优化蒸馏效果。
损失函数设计
模型蒸馏的损失函数通常由两部分组成:
硬标签损失:例如交叉熵损失,用于衡量学生模型预测与真实标签之间的差异。
软标签损失:例如 Kullback-Leibler 散度,用于衡量学生模型预测与教师模型输出概率分布之间的差异。
总损失可以表示为:
通过加权组合这两部分损失,可以平衡学生模型对硬标签和软标签的学习。
多任务学习与蒸馏
在某些情况下,可以将模型蒸馏与多任务学习结合,通过同时优化多个任务来提升学生模型的表现。这种方法有助于学生模型在多个方面模仿教师模型的能力。
案例分享
以下是一个完整的示例代码,从头训练教师模型并进行模型蒸馏到学生模型,我们以 CIFAR-10 数据集为例。
训练教师模型
首先,我们加载数据集并训练一个教师模型
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet34
# 数据预处理
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 定义设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 定义教师模型
teacher_model = resnet34(pretrained=False, num_classes=10).to(device)
# 教师模型训练
print("Training Teacher Model...")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.parameters(), lr=1e-3)
for epoch in range(5): # 使用较少的epoch演示
teacher_model.train()
total_loss = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
outputs = teacher_model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Teacher Model Epoch [{epoch+1}/5], Loss: {total_loss/len(train_loader):.4f}")
# 保存教师模型
torch.save(teacher_model.state_dict(), 'teacher_model.pth')
print("Teacher Model Saved!")
训练学生模型
student_model = resnet18(pretrained=False, num_classes=10).to(device)
# 定义蒸馏损失函数
class DistillationLoss(nn.Module):
def __init__(self, temperature=3.0, alpha=0.7):
super(DistillationLoss, self).__init__()
self.temperature = temperature
self.alpha = alpha
self.ce_loss = nn.CrossEntropyLoss()
self.kl_loss = nn.KLDivLoss(reduction='batchmean')
def forward(self, student_logits, teacher_logits, true_labels):
# 软目标损失
soft_loss = self.kl_loss(
nn.functional.log_softmax(student_logits / self.temperature, dim=1),
nn.functional.softmax(teacher_logits / self.temperature, dim=1)
) * (self.temperature ** 2)
# 硬目标损失
hard_loss = self.ce_loss(student_logits, true_labels)
# 总损失
return self.alpha * soft_loss + (1 - self.alpha) * hard_loss
distillation_loss = DistillationLoss(temperature=3.0, alpha=0.7)
# 加载教师模型权重
teacher_model.load_state_dict(torch.load('teacher_model.pth'))
teacher_model.eval()
# 蒸馏训练学生模型
print("Training Student Model with Distillation...")
optimizer = optim.Adam(student_model.parameters(), lr=1e-3)
for epoch in range(5): # 使用较少的epoch演示
student_model.train()
total_loss = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
# 学生模型预测
student_logits = student_model(images)
# 教师模型预测(无梯度)
with torch.no_grad():
teacher_logits = teacher_model(images)
# 计算蒸馏损失
loss = distillation_loss(student_logits, teacher_logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Student Model Epoch [{epoch+1}/5], Loss: {total_loss/len(train_loader):.4f}")
# 测试学生模型性能
print("Testing Student Model...")
student_model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = student_model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Student Model Accuracy: {100 * correct / total:.2f}%")
# 保存学生模型
torch.save(student_model.state_dict(), 'student_model.pth')
print("Student Model Saved!")
模型量化
模型量化的概念
量化是一种将较大尺寸的模型(如 LLM 或任何深度学习模型)压缩为较小尺寸的方法,比如最开始训练出的权重是32位的浮点数,但是实际使用发现用16位来表示也几乎没有什么损失,但是模型文件大小降低一般,显存使用降低一半,处理器和内存之间的通信带宽要求也降低了,这意味着更低的成本、更高的收益。
这就像按照菜谱做菜,你需要确定每种食材的重量。你可以使用一个非常精确的电子秤,它可以精确到0.01克,这固然很好,因为你可以非常精确地知道每样食材的重量。但是,如果你只是做一顿家常便饭,实际上并不需要这么高的精度,你可以使用一个简单又便宜的秤,最小刻度是1克,虽然不那么精确,但是足以用来做一顿美味的晚餐。
左侧:基础模型大小计算(单位:GB),右侧:量化后的模型大小计算(单位:GB)在上图中,基础模型 Llama 3 8B 的大小为 32 GB。经过 Int8 量化后,大小减少到 8GB(减少了 75%)。使用 Int4 量化后,大小进一步减少到 4GB(减少约 90%)。这使模型大小大幅减少。
量化还有一个好处,那就是计算的更快。
现代处理器中通常都包含了很多的低精度向量计算单元,模型可以充分利用这些硬件特性,执行更多的并行运算;同时低精度运算通常比高精度运算速度快,单次乘法、加法的耗时更短。这些好处还让模型得以运行在更低配置的机器上,比如没有高性能GPU的普通办公或家用电脑、手机等移动终端。
沿着这个思路,人们继续压缩出了8位、4位、2位的模型,体积更小,使用的计算资源更少。不过随着权重精度的降低,不同权重的值会越来越接近甚至相等,这会降低模型输出的准确度和精确度,模型的性能表现会出现不同程度的下降。
量化技术有很多不同的策略和技术细节,比如如动态量化、静态量化、对称量化、非对称量化等,对于大语言模型,通常采用静态量化的策略,在模型训练完成后,我们就对参数进行一次量化,模型运行时不再需要进行量化计算,这样可以方便地分发和部署。
量化的分类
根据不同的标准,量化方法可以被划分为多种类型:
按照量化时间点分类
后训练量化(Post-Training Quantization, PTQ):这是指在模型训练完成后对模型进行量化的过程。PTQ简单易行,适用于已经训练好的模型,但可能会带来一定的精度损失。
量化感知训练(Quantization-Aware Training, QAT):这种方法是在训练阶段引入量化机制,让模型在训练过程中“感知”到量化的影响,从而尽量减少量化带来的精度损失。虽然训练过程更为复杂且耗时较长,但它可以在保持较高精度的同时实现模型压缩。
按照量化粒度分类
Per-tensor量化:整个张量或层级共享相同的量化参数(scale和zero-point)。这种方式的优点是存储和计算效率较高,但可能导致精度损失。
Per-channel量化:每个通道或轴都有自己的量化参数。这种方式可以更准确地量化数据,因为每个通道可以根据自身特性调整动态范围,但会增加存储需求和计算复杂度。
Per-group量化:将数据分组处理,每组有自己的量化参数,介于上述两者之间。
按照量化后的数值范围分类
二值量化(Binary Quantization):将权重限制在+1和-1两个值之间。
三值量化(Ternary Quantization):允许使用三个离散值,通常是-1、0和+1。
定点数量化(Fixed-Point Quantization):最常见的是INT8和INT4,它们分别用8位和4位整数表示权重。
非均匀量化(Non-uniform Quantization):根据待量化参数的概率分布计算量化节点,以适应特定的数据分布模式。
按照是否线性映射分类
线性量化(Linear Quantization):采用线性映射的方式将浮点数映射到整数范围内。它可以进一步细分为对称量化和非对称量化两种形式。
非线性量化(Non-linear Quantization):例如对数量化,它不是简单的线性变换,而是基于某种函数关系来进行映射。
非对称量化的实现
此处以非对称量化为例。非对称量化方法将原始张量范围(Wmin, Wmax)中的值映射到量化张量范围(Qmin, Qmax)中的值。
Wmin, Wmax:原始张量的最小值和最大值(数据类型:FP32,32 位浮点)。在大多数现代 LLM 中,权重张量的默认数据类型是 FP32。
Qmin, Qmax: 量化张量的最小值和最大值(数据类型:INT8,8 位整数)。我们也可以选择其他数据类型,如 INT4、INT8、FP16 和 BF16 来进行量化。我们将在示例中使用 INT8。
缩放值(S):在量化过程中,缩放值将原始张量的值缩小以获得量化后的张量。在反量化过程中,它将量化后的张量值放大以获得反量化值。缩放值的数据类型与原始张量相同,为 FP32。
零点(Z):零点是量化张量范围中的一个非零值,它直接映射到原始张量范围中的值 0。零点的数据类型为 INT8,因为它位于量化张量范围内。
量化:图中的“A”部分展示了量化过程,即 [Wmin, Wmax] -> [Qmin, Qmax] 的映射。
反量化:图中的“B”部分展示了反量化过程,即 [Qmin, Qmax] -> [Wmin, Wmax] 的映射。
那么,我们如何从原始张量值导出量化后的张量值呢?这其实很简单。如果你还记得高中数学,你可以很容易理解下面的推导过程。让我们一步步来(建议在推导公式时参考上面的图表,以便更清晰地理解)。
细节1:如果Z值超出范围怎么办?解决方案:使用简单的if-else逻辑将Z值调整为Qmin,如果Z值小于Qmin;若Z值大于Qmax,则调整为Qmax。这个方法在图4的图A中有详细描述。
细节2:如果Q值超出范围怎么办?解决方案:在PyTorch中,有一个名为 clamp 的函数,它可以将值调整到特定范围内(在我们的示例中为-128到127)。因此,clamp函数会将Q值调整为Qmin如果它低于Qmin,将Q值调整为Qmax如果它高于Qmax。