目标检测实战(三):YOLO-Nano训练、测试、验证详细步骤

简介: 本文介绍了YOLO-Nano在目标检测中的训练、测试及验证步骤。YOLO-Nano是一个轻量级目标检测模型,使用ShuffleNet-v2作为主干网络,结合FPN+PAN特征金字塔和NanoDet的检测头。文章详细说明了训练前的准备、源代码下载、数据集准备、参数调整、模型测试、FPS测试、VOC-map测试、模型训练、模型测试和验证等步骤,旨在帮助开发者高效实现目标检测任务。

训练前准备

包括代码、数据集(VOC或者COCO)、调参等等…

下载源代码

受NanoDet启发的新版YOLO-Nano
网络架构分析:主干网:shufflenetv2,特征金字塔采用FPN+PAN,head用的是NanoDet的head

优化模型方式—多尺度学习、余弦退火、warmup、高分辨率、mosaic、KM聚类
损失函数:ciou_loss
预测框筛选:DIoU_nms

下载VOC和COCO数据集

这里给出一个公共数据集下载的网址:点击
修改数据集路径(voc0712.py的26行)

调参

这里主要是调整训练和epoch,no_warm_up选择False代表要采用预热模型的方式。
修改:
- epoch:config.py 里 5-8行
训练时设置use_cuda为True
训练时设置主干网的模型(yolo_nano_1.0x/yolo_nano_0.5x)
训练时设置dataset的类型VOC/COCO
如果要使用tensorboard则修改一下这里在这里插入图片描述
遇到的错误修改:
在这里插入图片描述

测试现有模式

测试图片检测和FPS

  • trained_model—修改为文件夹下存在的模型

自己根据评估那个写了个可以选择测试多张图片检测情况和FPS的代码,修改mode方式就可以。

import time
from PIL import Image
import cv2
import numpy as np
import os
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from data import *
import numpy as np
import cv2
import tools
import time
from data.voc0712 import VOCAnnotationTransform

parser = argparse.ArgumentParser(description='YOLO-Nano Detection')
parser.add_argument('-v', '--version', default='yolo_nano_1.0x',
                    help='yolo_nano_0.5x, yolo_nano_1.0x.')
parser.add_argument('-d', '--dataset', default='voc',
                    help='voc, coco-val.')
parser.add_argument('-size', '--input_size', default=416, type=int,
                    help='input_size')
parser.add_argument('--trained_model',
                    default=r'weights/voc/yolo_nano_1.0x/yolo_nano_1.0x_67.23.pth',
                    type=str, help='Trained state_dict file path to open')
parser.add_argument('--conf_thresh', default=0.1, type=float,
                    help='Confidence threshold')
parser.add_argument('--nms_thresh', default=0.50, type=float,
                    help='NMS threshold')
parser.add_argument('--visual_threshold', default=0.3, type=float,
                    help='Final confidence threshold')
parser.add_argument('--cuda', action='store_true', default=True,
                    help='use cuda.')
parser.add_argument('--diou_nms', action='store_true', default=False,
                    help='use diou nms.')

args = parser.parse_args()

def vis(img, bboxes, scores, cls_inds, thresh, class_colors, class_names, class_indexs=None, dataset='voc'):
    if dataset == 'voc':
        for i, box in enumerate(bboxes):
            cls_indx = cls_inds[i]
            xmin, ymin, xmax, ymax = box
            if scores[i] > thresh:
                cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), class_colors[int(cls_indx)], 1)
                cv2.rectangle(img, (int(xmin), int(abs(ymin) - 20)), (int(xmax), int(ymin)),
                              class_colors[int(cls_indx)], -1)
                mess = '%s' % (class_names[int(cls_indx)])
                cv2.putText(img, mess, (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    elif dataset == 'coco-val' and class_indexs is not None:
        for i, box in enumerate(bboxes):
            cls_indx = cls_inds[i]
            xmin, ymin, xmax, ymax = box
            if scores[i] > thresh:
                cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), class_colors[int(cls_indx)], 1)
                cv2.rectangle(img, (int(xmin), int(abs(ymin) - 20)), (int(xmax), int(ymin)),
                              class_colors[int(cls_indx)], -1)
                cls_id = class_indexs[int(cls_indx)]
                cls_name = class_names[cls_id]
                # mess = '%s: %.3f' % (cls_name, scores[i])
                mess = '%s' % (cls_name)
                cv2.putText(img, mess, (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    return img

def test(net, device, testset, transform, thresh, class_colors=None, class_names=None, class_indexs=None,
         dataset='voc',save=False,test_num=100,mode=''):
    num_images = len(testset)
    test_time,idx=[],1
    for index in range(num_images):
        print('Testing image {:d}/{:d}....'.format(index + 1, num_images))
        img, _ = testset.pull_image(index)
        img_tensor, _, h, w, offset, scale = testset.pull_item(index)

        # to tensor
        x = img_tensor.unsqueeze(0).to(device)

        t0 = time.time()
        # forward
        bboxes, scores, cls_inds = net(x)
        print("detection time used ", time.time() - t0, "s")
        if idx!=1:
            test_time.append(float(time.time() - t0))
        # scale each detection back up to the image
        max_line = max(h, w)
        # map the boxes to input image with zero padding
        bboxes *= max_line
        # map to the image without zero padding
        bboxes -= (offset * max_line)

        img_processed = vis(img, bboxes, scores, cls_inds, thresh, class_colors, class_names, class_indexs, dataset)
        if mode=='fps':
            if idx == test_num:
                break
            idx += 1
        else:
            cv2.imshow('detection', img_processed)
            cv2.waitKey(0)
            if save:
                print('Saving the' + str(index) + '-th image ...')
                save_path=r'D:\pycharm_Z\YOLO-Nano\img_files\save_detection_pic/'
                os.makedirs(os.path.dirname(save_path),exist_ok=True)
                cv2.imwrite( save_path+ str(index).zfill(6) +'.jpg', img_processed)
    return test_time

if __name__ == '__main__':
    # get device
    if args.cuda:
        print('use cuda')
        cudnn.benchmark = True
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    input_size = [args.input_size, args.input_size]

    # dataset
    if args.dataset == 'voc':
        print('test on voc ...')
        class_names = VOC_CLASSES
        class_indexs = None
        num_classes = 20
        anchor_size = MULTI_ANCHOR_SIZE
        dataset = VOCDetection(root=VOC_ROOT,
                               img_size=None,
                               image_sets=[('2007', 'test')],
                               transform=BaseTransform(input_size))

    elif args.dataset == 'coco-val':
        print('test on coco-val ...')
        class_names = coco_class_labels
        class_indexs = coco_class_index
        num_classes = 80
        anchor_size = MULTI_ANCHOR_SIZE_COCO
        dataset = COCODataset(
            data_dir=coco_root,
            json_file='instances_val2017.json',
            name='val2017',
            transform=BaseTransform(input_size),
            img_size=input_size[0])

    class_colors = [(np.random.randint(255), np.random.randint(255), np.random.randint(255)) for _ in
                    range(num_classes)]

    # build model
    if args.version == 'yolo_nano_0.5x':
        from models.yolo_nano import YOLONano

        backbone = '0.5x'
        net = YOLONano(device, input_size=input_size, num_classes=num_classes, anchor_size=anchor_size,
                       backbone=backbone)
        print('Let us train yolo_nano_0.5x ......')

    if args.version == 'yolo_nano_1.0x':
        from models.yolo_nano import YOLONano

        backbone = '1.0x'
        net = YOLONano(device, input_size=input_size, num_classes=num_classes, anchor_size=anchor_size,
                       backbone=backbone)
        print('Let us train yolo_nano_1.0x ......')

    else:
        print('Unknown version !!!')
        exit()

    net.load_state_dict(torch.load(args.trained_model, map_location=device))
    net.to(device).eval()

    print('Finished loading model!')
    #-------------------------------------------------------------------------#
    #   mode用于指定测试的模式:
    #   'predict'表示 多张图片预测和保存
    #   'fps'表示测试fps
    #-------------------------------------------------------------------------#
    mode = "predict"

    if mode == "predict":
        # evaluation
        test(net=net,
             device=device,
             testset=dataset,
             transform=BaseTransform(input_size),
             thresh=args.visual_threshold,
             class_colors=class_colors,
             class_names=class_names,
             class_indexs=class_indexs,
             dataset=args.dataset,
             save=True,
             mode="predict"
             )

    elif mode == "fps":
        # evaluation
        test_num=10
        time_all=test(net=net,
             device=device,
             testset=dataset,
             transform=BaseTransform(input_size),
             thresh=args.visual_threshold,
             class_colors=class_colors,
             class_names=class_names,
             class_indexs=class_indexs,
             dataset=args.dataset,
             test_num=test_num,
             mode="fps"
             )
        time_avg=sum(time_all)/len(time_all)
        print('the whole time:{}'.format(time_avg))
        print('fps:{}'.format(1/time_avg))

测试VOC-map

找到eval.py并修改测试模型路径和指定数据集路径
在这里插入图片描述
在这里插入图片描述
然后就可以运行eval.py文件。

!!!!如果遇到错误R = [obj for obj in recs[imagename] if obj[‘name’] == classname] KeyError: ‘007765’
解决办法------训练前需要将cache中的pki文件(找到voc_eval/test)以及VOCdevkit2007中annotations_cache的缓存删掉(在你的数据集里面会新建这个文件)

开始训练模型

训练文件脚本:train.py
修改完上述的地方应该就可以直接运行了(环境没问题的情况下)
在这里插入图片描述
每10个epoch会测试一下map
如果使用了tensorboard,可以在终端输入
tensorboard --logdir=D:\pycharm_Z\YOLO-Nano\log\voc\yolo_nano_1.0x\2021-09-13-13-31-02
出现一个网址进去就是
在这里插入图片描述
一个为分类loss,一个为回归loss,一个为多目标loss

测试模型

测试模型就可以使用我上面给出的代码,检测一下fps和图片检测情况
在这里插入图片描述
在这里插入图片描述

验证模型

如果是VOC数据集
验证模型的指标为Map,以及各类AP情况,通过eval.py,将模型改为你训练后的模型即可
如果是COCO数据集同理

目录
相关文章
|
2天前
|
编解码 Java 程序员
写代码还有专业的编程显示器?
写代码已经十个年头了, 一直都是习惯直接用一台Mac电脑写代码 偶尔接一个显示器, 但是可能因为公司配的显示器不怎么样, 还要接转接头 搞得桌面杂乱无章,分辨率也低,感觉屏幕还是Mac自带的看着舒服
|
4天前
|
存储 缓存 关系型数据库
MySQL事务日志-Redo Log工作原理分析
事务的隔离性和原子性分别通过锁和事务日志实现,而持久性则依赖于事务日志中的`Redo Log`。在MySQL中,`Redo Log`确保已提交事务的数据能持久保存,即使系统崩溃也能通过重做日志恢复数据。其工作原理是记录数据在内存中的更改,待事务提交时写入磁盘。此外,`Redo Log`采用简单的物理日志格式和高效的顺序IO,确保快速提交。通过不同的落盘策略,可在性能和安全性之间做出权衡。
1540 5
|
1月前
|
弹性计算 人工智能 架构师
阿里云携手Altair共拓云上工业仿真新机遇
2024年9月12日,「2024 Altair 技术大会杭州站」成功召开,阿里云弹性计算产品运营与生态负责人何川,与Altair中国技术总监赵阳在会上联合发布了最新的“云上CAE一体机”。
阿里云携手Altair共拓云上工业仿真新机遇
|
7天前
|
人工智能 Rust Java
10月更文挑战赛火热启动,坚持热爱坚持创作!
开发者社区10月更文挑战,寻找热爱技术内容创作的你,欢迎来创作!
581 22
|
4天前
|
存储 SQL 关系型数据库
彻底搞懂InnoDB的MVCC多版本并发控制
本文详细介绍了InnoDB存储引擎中的两种并发控制方法:MVCC(多版本并发控制)和LBCC(基于锁的并发控制)。MVCC通过记录版本信息和使用快照读取机制,实现了高并发下的读写操作,而LBCC则通过加锁机制控制并发访问。文章深入探讨了MVCC的工作原理,包括插入、删除、修改流程及查询过程中的快照读取机制。通过多个案例演示了不同隔离级别下MVCC的具体表现,并解释了事务ID的分配和管理方式。最后,对比了四种隔离级别的性能特点,帮助读者理解如何根据具体需求选择合适的隔离级别以优化数据库性能。
201 3
|
10天前
|
JSON 自然语言处理 数据管理
阿里云百炼产品月刊【2024年9月】
阿里云百炼产品月刊【2024年9月】,涵盖本月产品和功能发布、活动,应用实践等内容,帮助您快速了解阿里云百炼产品的最新动态。
阿里云百炼产品月刊【2024年9月】
|
11天前
|
Linux 虚拟化 开发者
一键将CentOs的yum源更换为国内阿里yum源
一键将CentOs的yum源更换为国内阿里yum源
580 5
|
23天前
|
存储 关系型数据库 分布式数据库
GraphRAG:基于PolarDB+通义千问+LangChain的知识图谱+大模型最佳实践
本文介绍了如何使用PolarDB、通义千问和LangChain搭建GraphRAG系统,结合知识图谱和向量检索提升问答质量。通过实例展示了单独使用向量检索和图检索的局限性,并通过图+向量联合搜索增强了问答准确性。PolarDB支持AGE图引擎和pgvector插件,实现图数据和向量数据的统一存储与检索,提升了RAG系统的性能和效果。
|
7天前
|
XML 安全 Java
【Maven】依赖管理,Maven仓库,Maven核心功能
【Maven】依赖管理,Maven仓库,Maven核心功能
233 3
|
9天前
|
存储 人工智能 搜索推荐
数据治理,是时候打破刻板印象了
瓴羊智能数据建设与治理产品Datapin全面升级,可演进扩展的数据架构体系为企业数据治理预留发展空间,推出敏捷版用以解决企业数据量不大但需构建数据的场景问题,基于大模型打造的DataAgent更是为企业用好数据资产提供了便利。
327 2