yolov4 LOSS代码详解【附代码】下

简介: 笔记

get_ignore:判断预测结果和真实结果的重合度


上面的get_target我们获得了gt在特征层的各种信息,它返回了y_true(记录当前特征层是由第几个anchor预测以及目标落在了哪个cell处);noobj_mask记录了哪些cell是没有目标的,box_loss_scale是各个目标对应于特征层比例值。


get_ignore这个函数是对预测结果进行解码,判断预测结果和真实值的重合度,如果重合度大则可以忽略,因为这部分说明预测的很准了。


该函数需要传入l(特征层索引),x,y,h,w(预测的box信息,即model输出的),targets(真实值),scaled_anchors(缩放后的anchors),in_h,in_w(特征层尺寸),noobj_mask(无目标的mask,shape[batch_size,3,19,19])。

        #---------------------------------------------------------------#
        #   将预测结果进行解码,判断预测结果和真实值的重合程度
        #   如果重合程度过大则忽略,因为这些特征点属于预测比较准确的特征点
        #   作为负样本不合适
        #----------------------------------------------------------------#
        noobj_mask, pred_boxes = self.get_ignore(l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask)

先把完整代码列出来:

    def get_ignore(self, l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask):
        #-----------------------------------------------------#
        #   计算一共有多少张图片
        #-----------------------------------------------------#
        bs = len(targets)
        FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
        LongTensor  = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
        #-----------------------------------------------------#
        #   生成网格,先验框中心,网格左上角
        #-----------------------------------------------------#
        grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(x.shape).type(FloatTensor)
        grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(y.shape).type(FloatTensor)
        # 生成先验框的宽高
        scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
        anchor_w = FloatTensor(scaled_anchors_l).index_select(1, LongTensor([0]))
        anchor_h = FloatTensor(scaled_anchors_l).index_select(1, LongTensor([1]))
        anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
        anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
        #-------------------------------------------------------#
        #   计算调整后的先验框中心与宽高
        #-------------------------------------------------------#
        pred_boxes_x    = torch.unsqueeze(x + grid_x, -1)
        pred_boxes_y    = torch.unsqueeze(y + grid_y, -1)
        pred_boxes_w    = torch.unsqueeze(torch.exp(w) * anchor_w, -1)
        pred_boxes_h    = torch.unsqueeze(torch.exp(h) * anchor_h, -1)
        pred_boxes      = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim = -1)
        for b in range(bs):           
            #-------------------------------------------------------#
            #   将预测结果转换一个形式
            #   pred_boxes_for_ignore      num_anchors, 4
            #-------------------------------------------------------#
            pred_boxes_for_ignore = pred_boxes[b].view(-1, 4)
            #-------------------------------------------------------#
            #   计算真实框,并把真实框转换成相对于特征层的大小
            #   gt_box      num_true_box, 4
            #-------------------------------------------------------#
            if len(targets[b]) > 0:
                batch_target = torch.zeros_like(targets[b])
                #-------------------------------------------------------#
                #   计算出正样本在特征层上的中心点
                #-------------------------------------------------------#
                batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
                batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
                batch_target = batch_target[:, :4]
                #-------------------------------------------------------#
                #   计算交并比
                #   anch_ious       num_true_box, num_anchors
                #-------------------------------------------------------#
                anch_ious = self.calculate_iou(batch_target, pred_boxes_for_ignore)
                #-------------------------------------------------------#
                #   每个先验框对应真实框的最大重合度
                #   anch_ious_max   num_anchors
                #-------------------------------------------------------#
                anch_ious_max, _    = torch.max(anch_ious, dim = 0)
                anch_ious_max       = anch_ious_max.view(pred_boxes[b].size()[:3])
                noobj_mask[b][anch_ious_max > self.ignore_threshold] = 0
        return noobj_mask, pred_boxes

生成网格

通过in_w,in_h我们可以划分cell,grid_x和grid_y的shape均为【batch_size,3,19,19】.

可以这样理解一下,每个anchor都对应19 * 19个网格。

        #-----------------------------------------------------#
        #   生成网格,先验框中心,网格左上角
        #-----------------------------------------------------#
        grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(x.shape).type(FloatTensor)
        grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(y.shape).type(FloatTensor)

grid_x:【这里我只取第一个batch和第一个anchor为例】


tensor([


       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.],

       [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,14., 15., 16., 17., 18.]], device='cuda:0')


grid_y:(大家猛的一看是不是看这种形式有点怪,实际这个grid_y是需要和上面grid_x进行对应的)


tensor([


       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,0.,  0.,  0.,  0.,  0.],

       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,1.,  1.,  1.,  1.,  1.],

       [ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,2.,  2.,  2.,  2.,  2.],

       [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,3.,  3.,  3.,  3.,  3.],

       [ 4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,4.,  4.,  4.,  4.,  4.],

       [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5., 5.,  5.,  5.,  5.,  5.],

       [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,6.,  6.,  6.,  6.,  6.],

       [ 7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,7.,  7.,  7.,  7.,  7.],

       [ 8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,8.,  8.,  8.,  8.,  8.],

       [ 9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.,9.,  9.,  9.,  9.,  9.],

       [10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.,10., 10., 10., 10., 10.],

       [11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,11., 11., 11., 11., 11.],

       [12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,12., 12., 12., 12., 12.],

       [13., 13., 13., 13., 13., 13., 13., 13., 13., 13., 13., 13., 13., 13.,13., 13., 13., 13., 13.],

       [14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14.,14., 14., 14., 14., 14.],

       [15., 15., 15., 15., 15., 15., 15., 15., 15., 15., 15., 15., 15., 15.,15., 15., 15., 15., 15.],

       [16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,16., 16., 16., 16., 16.],

       [17., 17., 17., 17., 17., 17., 17., 17., 17., 17., 17., 17., 17., 17.,17., 17., 17., 17., 17.],

       [18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18.,18., 18., 18., 18., 18.]], device='cuda:0')


获得先验框的w和h:

        # 生成先验框的宽高
        scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
        anchor_w = FloatTensor(scaled_anchors_l).index_select(1, LongTensor([0]))
        anchor_h = FloatTensor(scaled_anchors_l).index_select(1, LongTensor([1]))
        anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
        anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)

计算调整后的先验框中心和宽高:

x,y,w,h是model输出的box信息【预测值】,通过x+grid_x可以获得预测box位于19 * 19网格的哪个坐标处。得到的pred_boxes就是我们得到的在网格图上的预测box。shape为【batch_size,3,19,19,4】

        #-------------------------------------------------------#
        #   计算调整后的先验框中心与宽高
        #-------------------------------------------------------#
        pred_boxes_x    = torch.unsqueeze(x + grid_x, -1)
        pred_boxes_y    = torch.unsqueeze(y + grid_y, -1)
        pred_boxes_w    = torch.unsqueeze(torch.exp(w) * anchor_w, -1)
        pred_boxes_h    = torch.unsqueeze(torch.exp(h) * anchor_h, -1)
        pred_boxes      = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim = -1)

对上面的pred_boxes转换一些shape的形式,pred_boxes[0]的shape是【3,19,19,4】,对其进行平铺,变成【3*19*19,4】=【1083,4】,也就是我们得到的pred_boxes_for_ignore。

        for b in range(bs):           
            #-------------------------------------------------------#
            #   将预测结果转换一个形式
            #   pred_boxes_for_ignore      num_anchors, 4
            #-------------------------------------------------------#
            pred_boxes_for_ignore = pred_boxes[b].view(-1, 4)

这里再创建一个batch_target的全0tensor,功能和get_target函数中的batch_target一样,记录每个batch中所有目标真实值信息。


       

#-------------------------------------------------------#
            #   计算真实框,并把真实框转换成相对于特征层的大小
            #   gt_box      num_true_box, 4
            #-------------------------------------------------------#
            if len(targets[b]) > 0:
                batch_target = torch.zeros_like(targets[b])
                #-------------------------------------------------------#
                #   计算出正样本在特征层上的中心点
                #-------------------------------------------------------#
                batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
                batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
                batch_target = batch_target[:, :4]

targets的第一个batch中,出现了5个目标。 这里再说一下target表示的内容(主要怕大家看到这里又忘记了),分别表示center_x, center_y, w, h, class。


tensor([[0.2352, 0.1637, 0.2928, 0.3273, 0.0000],

       [0.2097, 0.5896, 0.1135, 0.4490, 0.0000],

       [0.8569, 0.5296, 0.1382, 0.2500, 0.0000],

       [0.6867, 0.5395, 0.1201, 0.2763, 0.0000],

       [0.6061, 0.1637, 0.2911, 0.3273, 0.0000]], device='cuda:0'),


由于上面的target信息是归一化到0~1间,我们需要映射到特征层上,进过上述操作得到batch_target:


tensor([


       [ 4.4688,  3.1094,  5.5625,  6.2188],

       [ 3.9844, 11.2031,  2.1562,  8.5312],

       [16.2812, 10.0625,  2.6250,  4.7500],

       [13.0469, 10.2500,  2.2813,  5.2500],

       [11.5156,  3.1094,  5.5312,  6.2188]], device='cuda:0')


计算交并比:

anch_ious的shape为【5,1083】,5指的就是当前batch中出现目标的数量,1083=3*19*19(当前特征层有多少anchors)。也就是说我们现在获得了所有真实值和预测值的box iou【和前面的get_target注意区分,get_target iou是先验框和真实框的iou,现在算的iou是真实值的预测值的】。


           

#-------------------------------------------------------#
                #   计算交并比
                #   anch_ious       num_true_box, num_anchors
                #-------------------------------------------------------#
                anch_ious = self.calculate_iou(batch_target, pred_boxes_for_ignore)

   

计算得到每个目标的真实框和预测框最大的iou了,并reshape成(3,19,19),就相当于知道每个cell内真实框和预测框的iou。


进而可以在nooj_mask进行筛选,将iou大于阈值(0.5)的置0.[表示这些地方有目标,就是正样本]


           

#-------------------------------------------------------#
                #   每个先验框对应真实框的最大重合度
                #   anch_ious_max   num_anchors
                #-------------------------------------------------------#
                anch_ious_max, _    = torch.max(anch_ious, dim = 0)
                anch_ious_max       = anch_ious_max.view(pred_boxes[b].size()[:3])
                noobj_mask[b][anch_ious_max > self.ignore_threshold] = 0

y_true的shape是【batch_size,3,19,19,5+classes】。5+classes指的是x,y.w.h,有无目标,当前为什么类。因此将y_true[...,4]置为1表示有目标,利用sum可以获得有多少正样本。


 

loss        = 0
        obj_mask    = y_true[..., 4] == 1
        n           = torch.sum(obj_mask)   # 有多少正样本

loss计算


接下来就是计算loss。


先计算ciou【具体计算过程和可视化在我另一篇文章有写】。【这里计算的是预测框和真实框的ciou】


1-ciou就是边界回归的loss_loc.


然后利用二分类交叉熵计算分类损失loss_cls。利用obj_mask在预测结果中进行筛选,pred_cls的shape为【batch_size,3,19,19,num_class[conf]】,obj_mask的shape【4,3,19,19】,表示在所有cell筛选出有目标的cell。y_true也是利用obj_mask进行筛选y_true[...,5:]表示对应的类。


 

if n != 0:
            #---------------------------------------------------------------#
            #   计算预测结果和真实结果的ciou
            #   ciou.shape = [batch_size,3,feature_w,feature_h]
            #----------------------------------------------------------------#
            ciou        = self.box_ciou(pred_boxes, y_true[..., :4])
            # loss_loc    = torch.mean((1 - ciou)[obj_mask] * box_loss_scale[obj_mask])
            loss_loc    = torch.mean((1 - ciou)[obj_mask])  # 边界回归
            loss_cls    = torch.mean(self.BCELoss(pred_cls[obj_mask], y_true[..., 5:][obj_mask]))  # 分类回归(只是判断有没有目标)
            loss        += loss_loc * self.box_ratio + loss_cls * self.cls_ratio
        if self.focal_loss:
            ratio       = torch.where(obj_mask, torch.ones_like(conf) * self.alpha, torch.ones_like(conf) * (1 - self.alpha)) * torch.where(obj_mask, torch.ones_like(conf) - conf, conf) ** self.gamma
            loss_conf   = torch.mean((self.BCELoss(conf, obj_mask.type_as(conf)) * ratio)[noobj_mask.bool() | obj_mask])
        else: 
            loss_conf   = torch.mean(self.BCELoss(conf, obj_mask.type_as(conf))[noobj_mask.bool() | obj_mask])  # 置信度回归
        loss        += loss_conf * self.balance[l] * self.obj_ratio
        # if n != 0:
        #     print(loss_loc * self.box_ratio, loss_cls * self.cls_ratio, loss_conf * self.balance[l] * self.obj_ratio)
        return loss


所以此时loc_loss=0.5211,loss_cls=0.8044.


置信度损失:


loss_conf   = torch.mean(self.BCELoss(conf, obj_mask.type_as(conf))[noobj_mask.bool() | obj_mask])  # 置信度回归

三者损失相加就是最后的loss损失。


上面就是针对loss部分的代码进行解析,可以更好的理解实现过程。有助于大家的理解。


目录
相关文章
|
8月前
|
计算机视觉
如何理解focal loss/GIOU(yolo改进损失函数)
如何理解focal loss/GIOU(yolo改进损失函数)
|
8月前
|
机器学习/深度学习
大模型训练loss突刺原因和解决办法
【1月更文挑战第19天】大模型训练loss突刺原因和解决办法
1178 1
大模型训练loss突刺原因和解决办法
|
8月前
|
机器学习/深度学习 计算机视觉
YOLOv8改进 | 2023 | InnerIoU、InnerSIoU、InnerWIoU、FocusIoU等损失函数
YOLOv8改进 | 2023 | InnerIoU、InnerSIoU、InnerWIoU、FocusIoU等损失函数
429 0
|
PyTorch 算法框架/工具
GoogLeNet InceptionV1代码复现+超详细注释(PyTorch)
GoogLeNet InceptionV1代码复现+超详细注释(PyTorch)
380 1
|
PyTorch 算法框架/工具 机器学习/深度学习
GoogLeNet InceptionV3代码复现+超详细注释(PyTorch)
GoogLeNet InceptionV3代码复现+超详细注释(PyTorch)
458 0
|
数据可视化 计算机视觉 异构计算
VariFocalNet | IoU-aware同V-Focal Loss全面提升密集目标检测(附YOLOV5测试代码)(二)
VariFocalNet | IoU-aware同V-Focal Loss全面提升密集目标检测(附YOLOV5测试代码)(二)
400 1
|
数据可视化 计算机视觉
Backbone | What?没有Normalize也可以起飞?全新Backbone之NF-ResNet(文末获取论文与源码)
Backbone | What?没有Normalize也可以起飞?全新Backbone之NF-ResNet(文末获取论文与源码)
101 1
|
数据挖掘 计算机视觉 网络架构
VariFocalNet | IoU-aware同V-Focal Loss全面提升密集目标检测(附YOLOV5测试代码)(一)
VariFocalNet | IoU-aware同V-Focal Loss全面提升密集目标检测(附YOLOV5测试代码)(一)
270 0
|
缓存 算法 PyTorch
YOLOv5的Tricks | 【Trick12】YOLOv5使用的数据增强方法汇总
YOLOv5的Tricks | 【Trick12】YOLOv5使用的数据增强方法汇总
3229 0
YOLOv5的Tricks | 【Trick12】YOLOv5使用的数据增强方法汇总
|
数据可视化 PyTorch 算法框架/工具

热门文章

最新文章

相关实验场景

更多