目标检测知识蒸馏---以SSD为例【附代码】上

简介: 笔记

在上一篇文章中有讲解以分类网络为例的知识蒸馏【分类网络知识蒸馏】,这篇文章将会针对目标检测网络进行蒸馏。


知识蒸馏是一种不改变网络结构模型压缩方法。这里的压缩需要和量化与剪枝进行区分,并不是严格意义上的压缩。这里将要讲的蒸馏是离线式蒸馏中的逻辑蒸馏【特征部分蒸馏以后会讲】,也是一种常用的方法,他是将已经训练好的teacher model对student model进行蒸馏。


teacher model是一个在精度表现上优良的模型,而student model往往是精度差一些,但推理速度高的模型。如果要采用这种蒸馏方式,需要注意的是两个Model的网络结构需要相似【因此可以将改进前后的model建立这种关系】。而实现部分最最重要的部分是建立蒸馏的Loss函数。


在目标检测中主要有两个任务,一个是分类,一个是边界的回归,前者的蒸馏是比较容易的,关键是在后者,这也是蒸馏的一个难点。


我们先来看一下SSD代码中的MultiBoxloss部分详解。


MultiBoxloss


SSD中分类loss采用CrossEntropy,边界loss采用平滑L1。具体公式和网络算法原理参考论文,这里不在多说。


loss参数说明

参数说明:


self.use_gpu:是否采用gpu训练


self.num_classes:训练类的数量【在SSD中num_classes是自己的类数量+背景类】


self.threshold:阈值,默认0.5


self.background_label:背景类标签,默认为0


self.encode_target:target编码


self.use_prior_for_matching:利用先眼眶做匹配


self.do_neg_mining:True,负样本挖掘


self.negpos_ratio:负样本比例,设置为3【正负样本比例为1:3】


self.variance :方差

class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
        Compute Targets:
            1) Produce Confidence Target Indices by matching  ground truth boxes
               with (default) 'priorboxes' that have jaccard index > threshold parameter
               (default threshold: 0.5).
            2) Produce localization target by 'encoding' variance into offsets of ground
               truth boxes and their matched  'priorboxes'.
            3) Hard negative mining to filter the excessive number of negative examples
               that comes with using a large number of default bounding boxes.
               (default negative:positive ratio 3:1)
        Objective Loss:
            L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
            Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
            weighted by α which is set to 1 by cross val.
            Args:
                c: class confidences,
                l: predicted boxes,
                g: ground truth boxes
                N: number of matched default boxes
            See: https://arxiv.org/pdf/1512.02325.pdf for more details.
        """
    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True, negatives_for_hard=100.0):
        super(MultiBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.negatives_for_hard = negatives_for_hard
        self.variance = Config['variance']

loss forward部分

predictions:类型为tuple,网络的输出内容,包含:位置预测,分类置信度预测以及prior boxes预测。


predictions[0]的shape为:[batch,8732,4]


predictions[1]的shape为:[batch,8732,num_classes]


predictions[2]的shape为:[8732,4]


注:8732:以输入大小300*300为例,将在6个head部分产生8732个先眼眶.


8732= 38*38*4 + 19*19*6 + 10*10*6 + 5*5*6 + 3*3*6 + 1*1*4


target:包含了标注的数据集真实的boxes坐标以及label信息。是一个列表,列表的长度等于batch的数量,每个列表中的元素shape为[num_objs,5],num_objs表示你当前图像中标注的目标数量,5=boxes信息+label信息。

    def forward(self, predictions, targets):
        """Multibox Loss
                Args:
                    predictions (tuple): A tuple containing loc preds, conf preds,
                    and prior boxes from SSD net.
                        conf shape: torch.size(batch_size,num_priors,num_classes)
                        loc shape: torch.size(batch_size,num_priors,4)
                        priors shape: torch.size(num_priors,4)
                    pred_t (tuple): teacher's predictions
                    targets (tensor): Ground truth boxes and labels for a batch,
                        shape: [batch_size,num_objs,5] (last idx is the label).
                """
        #--------------------------------------------------#
        #   取出预测结果的三个值:回归信息,置信度,先验框
        #--------------------------------------------------#
        loc_data, conf_data, priors = predictions

创建两个全零张量用来做先验框和真实框的匹配,这里的num等于batch_size

loc_t = torch.zeros(num, num_priors, 4).type(torch.FloatTensor)
conf_t = torch.zeros(num, num_priors).long()

遍历每个batch,idx是batch的索引,truths是获取到的真实值的boxes信息。labels是获取到的当前图像中是什么类。


truths:tensor([[0.2333, 0.2067, 0.6967, 1.0000]], device='cuda:0')
labels:tensor([0.], device='cuda:0')
        for idx in range(num):
            # 获得真实框与标签
            truths = targets[idx][:, :-1]
            labels = targets[idx][:, -1]
            if(len(truths)==0):
                continue
            # 获得先验框
            defaults = priors
            #--------------------------------------------------#
            #   利用真实框和先验框进行匹配。
            #   如果真实框和先验框的重合度较高,则认为匹配上了。
            #   该先验框用于负责检测出该真实框。
            #--------------------------------------------------#
            match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx)

defaults是先验框,接下来是和真实框进行标签匹配。


标签匹配函数match

       传入参数:threshold,truths[真实boxes],defaults[先验框],variance[方差],labels[真实标签],loc_t[前面创建的全零张量,用来存放匹配后的boxes信息], conf_t[用来存储匹配后的置信度分类信息],idx[当前batch的索引 ]。


1.先计算所有先验框和真实框的的重合程度。


       box_a是就是上面的truths,box_b是先验框【注意先验框中的boxes形式是center_x,center_y,w,h,需要先转成左上角和右下角的形式】。最终就可以计算出IOU。

def jaccard(box_a, box_b):
    #-------------------------------------#
    #   返回的inter的shape为[A,B]
    #   代表每一个真实框和先验框的交矩形
    #-------------------------------------#
    inter = intersect(box_a, box_b)
    #-------------------------------------#
    #   计算先验框和真实框各自的面积
    #-------------------------------------#
    area_a = ((box_a[:, 2]-box_a[:, 0]) *
              (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2]-box_b[:, 0]) *
              (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
    union = area_a + area_b - inter
    #-------------------------------------#
    #   每一个真实框和先验框的交并比[A,B]
    #-------------------------------------#
    return inter / union

因此得到的overlaps是计算的所有先验框和真实框的iou,shape为[1,8732]。


overlaps = jaccard(
        truths,
        point_form(priors)
    )

接下来是通过max函数获得这8732个先验框中与真实框匹配度最好的框和索引【就相当于可以把这个匹配的最好的认为是ground truth】。


可以得到iou最高的是0.6904,是第8711号先验框。


best_prior_overlap:tensor([[0.6904]], device='cuda:0')
best_prior_idx:tensor([[8711]], device='cuda:0')


用于保证每个真实框都有一个先验框与之匹配。

for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j
    best_truth_overlap.index_fill_(0, best_prior_idx, 2)

将truths扩充成8732.


matches = truths[best_truth_idx]

获取标签


conf = labels[best_truth_idx] + 1

获取背景类,通过设置的iou阈值进行过滤。


conf[best_truth_overlap < threshold] = 0

进行边界框的编码【其实就是将真实框和先验框进行匹配】。


def encode(matched, priors, variances):
    g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
    g_cxcy /= (variances[0] * priors[:, 2:])
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    g_wh = torch.log(g_wh) / variances[1]
    return torch.cat([g_cxcy, g_wh], 1)

获得的loc shape为【8732,4】

loc = encode(matches, priors, variances)

将编码后的loc放入前面定义loc_t中,conf也是如此。

获得正样本。

# 所有conf_t>0的地方,代表内部包含物体
pos = conf_t > 0

此时的pos形式如下,shape为【batch,8732】:


tensor([[False, False, False,  ..., False, False,  True],

       [False, False, False,  ..., False, False, False]], device='cuda:0')


求和得到每个图像内有多少正样本。这就可以计算出在所有的batch中的所有batch*8732个框中有多少框内包含目标。


num_pos = pos.sum(dim=1, keepdim=True)


相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
目录
相关文章
|
6月前
|
机器学习/深度学习 编解码 人工智能
|
3月前
|
PyTorch 算法框架/工具 计算机视觉
目标检测实战(二):YoloV4-Tiny训练、测试、评估完整步骤
本文介绍了使用YOLOv4-Tiny进行目标检测的完整流程,包括模型介绍、代码下载、数据集处理、网络训练、预测和评估。
209 2
目标检测实战(二):YoloV4-Tiny训练、测试、评估完整步骤
|
7月前
|
测试技术 计算机视觉
【YOLOv8性能对比试验】YOLOv8n/s/m/l/x不同模型尺寸大小的实验结果对比及结论参考
【YOLOv8性能对比试验】YOLOv8n/s/m/l/x不同模型尺寸大小的实验结果对比及结论参考
|
7月前
|
机器学习/深度学习 计算机视觉
【YOLO性能对比试验】YOLOv9c/v8n/v6n/v5n的训练结果对比及结论参考
【YOLO性能对比试验】YOLOv9c/v8n/v6n/v5n的训练结果对比及结论参考
|
8月前
|
机器学习/深度学习 人工智能 自然语言处理
RT-DETR原理与简介(干翻YOLO的最新目标检测项目)
RT-DETR原理与简介(干翻YOLO的最新目标检测项目)
|
8月前
|
机器学习/深度学习 人工智能 自然语言处理
PyTorch搭建图卷积神经网络(GCN)完成对论文分类及预测实战(附源码和数据集)
PyTorch搭建图卷积神经网络(GCN)完成对论文分类及预测实战(附源码和数据集)
374 1
|
机器学习/深度学习 算法 数据可视化
详细解读GraphFPN | 如何用图模型提升目标检测模型性能?
详细解读GraphFPN | 如何用图模型提升目标检测模型性能?
203 0
|
并行计算 固态存储 Linux
|
固态存储 计算机视觉 索引
|
固态存储 计算机视觉
【24】目标检测模型SSD的搭建及其训练与测试
【24】目标检测模型SSD的搭建及其训练与测试
271 0
【24】目标检测模型SSD的搭建及其训练与测试

热门文章

最新文章