【科普向】模型蒸馏和模型量化到底是什么???

简介: 在数字化快速发展的时代,人工智能(AI)技术已广泛应用,但大型深度学习模型对计算资源的需求日益增长,增加了部署成本并限制了其在资源有限环境下的应用。为此,研究人员提出了模型蒸馏和模型量化两种关键技术。模型蒸馏通过将大型教师模型的知识传递给小型学生模型,利用软标签指导训练,使学生模型在保持较高准确性的同时显著减少计算需求,特别适用于移动设备和嵌入式系统。模型量化则是通过降低模型权重的精度(如从32位浮点数到8位整数),大幅减少模型大小和计算量,提高运行速度,并能更好地适应低配置设备。量化分为后训练量化和量化感知训练等多种方法,各有优劣。

引言

在当今数字化快速发展的时代,人工智能(AI)技术已经渗透到我们生活的方方面面,从智能手机中的语音助手到自动驾驶车辆的安全系统,背后都离不开深度学习模型的支持。然而,随着这些模型变得越来越庞大和复杂,它们对计算资源的需求也日益增长,这不仅增加了部署成本,还限制了AI应用在资源有限环境下的广泛应用。例如,在移动设备或边缘计算场景中,由于硬件性能和功耗的限制,直接部署大型模型往往不可行。

为了解决这一问题,研究人员提出了两种关键技术——模型蒸馏(Model Distillation)与模型量化(Model Quantization),它们旨在通过不同的方式压缩复杂的深度学习模型,使得小型化后的模型能够在保持较高准确性的前提下,更高效地运行于各种平台上。这两种方法虽然侧重点不同,但都是为了实现同一个目标:让AI更加轻便、节能且易于部署

本文将以科普的形式向读者介绍模型蒸馏和模型量化的定义、工作原理及其应用场景,帮助大家理解这两项技术如何助力AI模型瘦身

模型蒸馏

模型蒸馏的概念

模型蒸馏(Model Distillation)是一种模型压缩和知识迁移的技术,旨在将一个大型、复杂且性能优异的教师模型(Teacher Model)中的知识传递给一个较小、计算效率更高的学生模型(Student Model),将复杂且大的模型作为Teacher,Student模型结构较为简单,用Teacher来辅助Student模型的训练,Teacher学习能力强,可以将它学到的知识迁移给学习能力相对弱的Student模型,以此来增强Student模型的泛化能力,复杂笨重但是效果好的Teacher模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的Student小模型。

其核心思想是利用教师模型输出的软标签(soft targets)—— 即概率分布而非硬标签(hard labels),来指导学生模型的训练。通过这种方式,学生模型不仅学习到数据的类别信息,还能够捕捉到类别之间的相似性和关系,从而提升其泛化能力。

该方法的优势在于能够在不显著损失性能的情况下,显著减少模型大小和计算需求,特别适用于资源受限的设备,如移动设备和嵌入式系统。

image.png

主要步骤

image.png

模型蒸馏通常包括以下几个步骤。

  1. 训练教师模型(Teacher Model):首先训练一个性能优异但通常较为庞大的教师模型。教师模型可以是任何高性能的深度学习模型,如深层神经网络、卷积神经网络(CNN)、Transformer等。

  2. 生成软标签(Soft Targets):使用训练好的教师模型对训练数据进行预测,获得每个样本的概率分布。这些概率分布作为软标签,包含了类别之间的相对关系信息。

  3. 训练学生模型(Student Model):设计一个较小的学生模型,并使用软标签以及硬标签共同训练。训练过程中,通常采用一个损失函数的加权组合,例如,交叉熵损失(用于硬标签)与 Kullback-Leibler 散度损失(用于软标签)。

  4. 优化与调整:通过调整温度参数、损失函数权重等超参数,优化学生模型的性能,使其尽可能接近教师模型。

关键技术与方法

软标签与温度参数

传统的训练方法通常使用硬标签,即每个样本对应一个确定的类别标签。而在模型蒸馏中,教师模型输出的是概率分布(软标签),这些概率反映了教师模型对各类别的信心程度。通过引入温度系数(temperature),可以平滑或锐化这个概率分布,从而提供更丰富的梯度信息,帮助学生模型更好地学习。

而对于温度系数,我们可以这么理解,假设有一位老师讲课速度非常快,信息密度很高,学生可能有点难以跟上。这时如果老师放慢速度,简化信息,就会让学生更容易理解。在模型蒸馏中,温度参数起到的就是类似“调节讲课速度”的作用,帮助学生模型(小模型)更好地理解和学习教师模型(大模型)的知识。专业点说就是让模型输出更加平滑的概率分布,方便学生模型捕捉和学习教师模型的输出细节。

数学表达式为:

image.png

较高的温度会使得输出分布更加平滑,能够更好地揭示类别之间的相似性,从而提供更丰富的知识给学生模型。训练过程中,通常会同时调整温度参数来优化蒸馏效果。

损失函数设计

模型蒸馏的损失函数通常由两部分组成:

  • 硬标签损失:例如交叉熵损失,用于衡量学生模型预测与真实标签之间的差异。

  • 软标签损失:例如 Kullback-Leibler 散度,用于衡量学生模型预测与教师模型输出概率分布之间的差异。

总损失可以表示为:

image.png

通过加权组合这两部分损失,可以平衡学生模型对硬标签和软标签的学习。

多任务学习与蒸馏

在某些情况下,可以将模型蒸馏与多任务学习结合,通过同时优化多个任务来提升学生模型的表现。这种方法有助于学生模型在多个方面模仿教师模型的能力。

案例分享

以下是一个完整的示例代码,从头训练教师模型并进行模型蒸馏到学生模型,我们以 CIFAR-10 数据集为例。

训练教师模型

首先,我们加载数据集并训练一个教师模型

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet34

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 定义设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 定义教师模型
teacher_model = resnet34(pretrained=False, num_classes=10).to(device)

# 教师模型训练
print("Training Teacher Model...")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.parameters(), lr=1e-3)

for epoch in range(5):  # 使用较少的epoch演示
    teacher_model.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = teacher_model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Teacher Model Epoch [{epoch+1}/5], Loss: {total_loss/len(train_loader):.4f}")

# 保存教师模型
torch.save(teacher_model.state_dict(), 'teacher_model.pth')
print("Teacher Model Saved!")

训练学生模型

student_model = resnet18(pretrained=False, num_classes=10).to(device)

# 定义蒸馏损失函数
class DistillationLoss(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.7):
        super(DistillationLoss, self).__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, true_labels):
        # 软目标损失
        soft_loss = self.kl_loss(
            nn.functional.log_softmax(student_logits / self.temperature, dim=1),
            nn.functional.softmax(teacher_logits / self.temperature, dim=1)
        ) * (self.temperature ** 2)
        # 硬目标损失
        hard_loss = self.ce_loss(student_logits, true_labels)
        # 总损失
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

distillation_loss = DistillationLoss(temperature=3.0, alpha=0.7)

# 加载教师模型权重
teacher_model.load_state_dict(torch.load('teacher_model.pth'))
teacher_model.eval()

# 蒸馏训练学生模型
print("Training Student Model with Distillation...")
optimizer = optim.Adam(student_model.parameters(), lr=1e-3)

for epoch in range(5):  # 使用较少的epoch演示
    student_model.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # 学生模型预测
        student_logits = student_model(images)
        # 教师模型预测(无梯度)
        with torch.no_grad():
            teacher_logits = teacher_model(images)

        # 计算蒸馏损失
        loss = distillation_loss(student_logits, teacher_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Student Model Epoch [{epoch+1}/5], Loss: {total_loss/len(train_loader):.4f}")

# 测试学生模型性能
print("Testing Student Model...")
student_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = student_model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Student Model Accuracy: {100 * correct / total:.2f}%")

# 保存学生模型
torch.save(student_model.state_dict(), 'student_model.pth')
print("Student Model Saved!")

模型量化

模型量化的概念

量化是一种将较大尺寸的模型(如 LLM 或任何深度学习模型)压缩为较小尺寸的方法,比如最开始训练出的权重是32位的浮点数,但是实际使用发现用16位来表示也几乎没有什么损失,但是模型文件大小降低一般,显存使用降低一半,处理器和内存之间的通信带宽要求也降低了,这意味着更低的成本、更高的收益。

image.png

这就像按照菜谱做菜,你需要确定每种食材的重量。你可以使用一个非常精确的电子秤,它可以精确到0.01克,这固然很好,因为你可以非常精确地知道每样食材的重量。但是,如果你只是做一顿家常便饭,实际上并不需要这么高的精度,你可以使用一个简单又便宜的秤,最小刻度是1克,虽然不那么精确,但是足以用来做一顿美味的晚餐。

image.png

左侧:基础模型大小计算(单位:GB),右侧:量化后的模型大小计算(单位:GB)在上图中,基础模型 Llama 3 8B 的大小为 32 GB。经过 Int8 量化后,大小减少到 8GB(减少了 75%)。使用 Int4 量化后,大小进一步减少到 4GB(减少约 90%)。这使模型大小大幅减少。

量化还有一个好处,那就是计算的更快

现代处理器中通常都包含了很多的低精度向量计算单元,模型可以充分利用这些硬件特性,执行更多的并行运算;同时低精度运算通常比高精度运算速度快,单次乘法、加法的耗时更短。这些好处还让模型得以运行在更低配置的机器上,比如没有高性能GPU的普通办公或家用电脑、手机等移动终端。

沿着这个思路,人们继续压缩出了8位、4位、2位的模型,体积更小,使用的计算资源更少。不过随着权重精度的降低,不同权重的值会越来越接近甚至相等,这会降低模型输出的准确度和精确度,模型的性能表现会出现不同程度的下降。

量化技术有很多不同的策略和技术细节,比如如动态量化、静态量化、对称量化、非对称量化等,对于大语言模型,通常采用静态量化的策略,在模型训练完成后,我们就对参数进行一次量化,模型运行时不再需要进行量化计算,这样可以方便地分发和部署。

量化的分类

根据不同的标准,量化方法可以被划分为多种类型:

按照量化时间点分类

  • 后训练量化(Post-Training Quantization, PTQ):这是指在模型训练完成后对模型进行量化的过程。PTQ简单易行,适用于已经训练好的模型,但可能会带来一定的精度损失。

  • 量化感知训练(Quantization-Aware Training, QAT):这种方法是在训练阶段引入量化机制,让模型在训练过程中“感知”到量化的影响,从而尽量减少量化带来的精度损失。虽然训练过程更为复杂且耗时较长,但它可以在保持较高精度的同时实现模型压缩。

按照量化粒度分类

  • Per-tensor量化:整个张量或层级共享相同的量化参数(scale和zero-point)。这种方式的优点是存储和计算效率较高,但可能导致精度损失。

  • Per-channel量化:每个通道或轴都有自己的量化参数。这种方式可以更准确地量化数据,因为每个通道可以根据自身特性调整动态范围,但会增加存储需求和计算复杂度。

  • Per-group量化:将数据分组处理,每组有自己的量化参数,介于上述两者之间。

按照量化后的数值范围分类

  • 二值量化(Binary Quantization):将权重限制在+1和-1两个值之间。

  • 三值量化(Ternary Quantization):允许使用三个离散值,通常是-1、0和+1。

  • 定点数量化(Fixed-Point Quantization):最常见的是INT8和INT4,它们分别用8位和4位整数表示权重。

  • 非均匀量化(Non-uniform Quantization):根据待量化参数的概率分布计算量化节点,以适应特定的数据分布模式。

按照是否线性映射分类

  • 线性量化(Linear Quantization):采用线性映射的方式将浮点数映射到整数范围内。它可以进一步细分为对称量化和非对称量化两种形式。

  • 非线性量化(Non-linear Quantization):例如对数量化,它不是简单的线性变换,而是基于某种函数关系来进行映射。

非对称量化的实现

此处以非对称量化为例。非对称量化方法将原始张量范围(Wmin, Wmax)中的值映射到量化张量范围(Qmin, Qmax)中的值。

image.png

  • Wmin, Wmax:原始张量的最小值和最大值(数据类型:FP32,32 位浮点)。在大多数现代 LLM 中,权重张量的默认数据类型是 FP32。

  • Qmin, Qmax: 量化张量的最小值和最大值(数据类型:INT8,8 位整数)。我们也可以选择其他数据类型,如 INT4、INT8、FP16 和 BF16 来进行量化。我们将在示例中使用 INT8。

  • 缩放值(S):在量化过程中,缩放值将原始张量的值缩小以获得量化后的张量。在反量化过程中,它将量化后的张量值放大以获得反量化值。缩放值的数据类型与原始张量相同,为 FP32。

  • 零点(Z):零点是量化张量范围中的一个非零值,它直接映射到原始张量范围中的值 0。零点的数据类型为 INT8,因为它位于量化张量范围内。

  • 量化:图中的“A”部分展示了量化过程,即 [Wmin, Wmax] -> [Qmin, Qmax] 的映射。

  • 反量化:图中的“B”部分展示了反量化过程,即 [Qmin, Qmax] -> [Wmin, Wmax] 的映射。

那么,我们如何从原始张量值导出量化后的张量值呢?这其实很简单。如果你还记得高中数学,你可以很容易理解下面的推导过程。让我们一步步来(建议在推导公式时参考上面的图表,以便更清晰地理解)。

image.png
image.png

细节1:如果Z值超出范围怎么办?解决方案:使用简单的if-else逻辑将Z值调整为Qmin,如果Z值小于Qmin;若Z值大于Qmax,则调整为Qmax。这个方法在图4的图A中有详细描述。

细节2:如果Q值超出范围怎么办?解决方案:在PyTorch中,有一个名为 clamp 的函数,它可以将值调整到特定范围内(在我们的示例中为-128到127)。因此,clamp函数会将Q值调整为Qmin如果它低于Qmin,将Q值调整为Qmax如果它高于Qmax。

image.png

模型蒸馏和模型量化对比

image.png

相关文章
|
14天前
|
供应链 监控 安全
对话|企业如何构建更完善的容器供应链安全防护体系
阿里云与企业共筑容器供应链安全
171330 12
|
16天前
|
供应链 监控 安全
对话|企业如何构建更完善的容器供应链安全防护体系
随着云计算和DevOps的兴起,容器技术和自动化在软件开发中扮演着愈发重要的角色,但也带来了新的安全挑战。阿里云针对这些挑战,组织了一场关于云上安全的深度访谈,邀请了内部专家穆寰、匡大虎和黄竹刚,深入探讨了容器安全与软件供应链安全的关系,分析了当前的安全隐患及应对策略,并介绍了阿里云提供的安全解决方案,包括容器镜像服务ACR、容器服务ACK、网格服务ASM等,旨在帮助企业构建涵盖整个软件开发生命周期的安全防护体系。通过加强基础设施安全性、技术创新以及倡导协同安全理念,阿里云致力于与客户共同建设更加安全可靠的软件供应链环境。
150295 32
|
24天前
|
弹性计算 人工智能 安全
对话 | ECS如何构筑企业上云的第一道安全防线
随着中小企业加速上云,数据泄露、网络攻击等安全威胁日益严重。阿里云推出深度访谈栏目,汇聚产品技术专家,探讨云上安全问题及应对策略。首期节目聚焦ECS安全性,提出三道防线:数据安全、网络安全和身份认证与权限管理,确保用户在云端的数据主权和业务稳定。此外,阿里云还推出了“ECS 99套餐”,以高性价比提供全面的安全保障,帮助中小企业安全上云。
201961 14
对话 | ECS如何构筑企业上云的第一道安全防线
|
2天前
|
机器学习/深度学习 自然语言处理 PyTorch
深入剖析Transformer架构中的多头注意力机制
多头注意力机制(Multi-Head Attention)是Transformer模型中的核心组件,通过并行运行多个独立的注意力机制,捕捉输入序列中不同子空间的语义关联。每个“头”独立处理Query、Key和Value矩阵,经过缩放点积注意力运算后,所有头的输出被拼接并通过线性层融合,最终生成更全面的表示。多头注意力不仅增强了模型对复杂依赖关系的理解,还在自然语言处理任务如机器翻译和阅读理解中表现出色。通过多头自注意力机制,模型在同一序列内部进行多角度的注意力计算,进一步提升了表达能力和泛化性能。
|
6天前
|
存储 人工智能 安全
对话|无影如何助力企业构建办公安全防护体系
阿里云无影助力企业构建办公安全防护体系
1251 8
|
7天前
|
人工智能 自然语言处理 程序员
通义灵码2.0全新升级,AI程序员全面开放使用
通义灵码2.0来了,成为全球首个同时上线JetBrains和VSCode的AI 程序员产品!立即下载更新最新插件使用。
1291 24
|
9天前
|
机器学习/深度学习 自然语言处理 搜索推荐
自注意力机制全解析:从原理到计算细节,一文尽览!
自注意力机制(Self-Attention)最早可追溯至20世纪70年代的神经网络研究,但直到2017年Google Brain团队提出Transformer架构后才广泛应用于深度学习。它通过计算序列内部元素间的相关性,捕捉复杂依赖关系,并支持并行化训练,显著提升了处理长文本和序列数据的能力。相比传统的RNN、LSTM和GRU,自注意力机制在自然语言处理(NLP)、计算机视觉、语音识别及推荐系统等领域展现出卓越性能。其核心步骤包括生成查询(Q)、键(K)和值(V)向量,计算缩放点积注意力得分,应用Softmax归一化,以及加权求和生成输出。自注意力机制提高了模型的表达能力,带来了更精准的服务。
|
7天前
|
消息中间件 人工智能 运维
1月更文特别场——寻找用云高手,分享云&AI实践
我们寻找你,用云高手,欢迎分享你的真知灼见!
563 22
1月更文特别场——寻找用云高手,分享云&AI实践
|
6天前
|
机器学习/深度学习 人工智能 自然语言处理
|
12天前
|
人工智能 自然语言处理 API
阿里云百炼xWaytoAGI共学课DAY1 - 必须了解的企业级AI应用开发知识点
本课程旨在介绍阿里云百炼大模型平台的核心功能和应用场景,帮助开发者和技术小白快速上手,体验AI的强大能力,并探索企业级AI应用开发的可能性。