NeurIPS 2022 | 四分钟内就能训练目标检测器,商汤基模型团队是怎么做到的?

简介: NeurIPS 2022 | 四分钟内就能训练目标检测器,商汤基模型团队是怎么做到的?


机器之心编辑部

来自商汤的基模型团队和香港大学等机构的研究人员提出了一种大批量训练算法 AGVM,该研究已被NeurIPS 2022接收。


本文提出了一种大批量训练算法 AGVM (Adaptive Gradient Variance Modulator),不仅可以适配于目标检测任务,同时也可以适配各类分割任务。AGVM 可以把目标检测的训练批量大小扩大到 1536,帮助研究人员四分钟训练 Faster R-CNN,3.5 小时把 COCO 刷到 62.2 mAP,均打破了目标检测训练速度的世界纪录。



在当前的机器学习社区中,有三个普遍的趋势。首先,神经网络模型会越来越大。在 NLP 领域中最大规模的模型已经达到了上万亿级别。在视觉领域,最大规模的模型也达到了三百亿的量级。其次,训练的数据集也变得越来越大。比如,ImageNet 21k 和谷歌的 JFT 数据集都具有相当规模的数据集。另外,由于数据集变得越来越大,训练 SOTA 模型的开销越来越大。


因此,提升训练效率就变得愈发重要。而分布式训练因为其适应于数据并行、模型并行和流水线并行的加速训练方法的同时,也具备较高的 Deep Learning 通信效率而被广泛认为是一个有效的解决方案。


随着大模型时代的到来,目标检测器的训练速度越来越成为学术界和工业界的瓶颈,例如,在 COCO 的标准 setting 上把 mAP 训到 62 以上大概需要三天的时间,算上调试成本,这在业界几乎是不可接受的。那么,我们能不能把这个训练时间压到小时级别呢?事实上,在图片分类和自然语言处理任务上,先前的研究人员借助 32K 的批量大小(batch size),只需 14 分钟就可以完成 ImageNet 的训练,76 分钟完成 Bert 的训练。但是,在目标检测领域,还很欠缺这类研究,导致研究人员无法充分利用当前的算力,数据集和大模型。


大批量训练算法 AGVM 便是这个问题的最佳解决方案之一。为了支持如此大批量的训练,同时保持模型的训练精度,本研究提出了一套全新的训练算法,根据密集预测不同模块的梯度方差(gradient variance),动态调整每一个模块的学习率。作者在大量的密集预测网络和数据集上进行了实验,并且证实了该方法的合理性。

 

方法介绍


大批量训练是加速大型分布式系统中深度神经网络训练的关键。尤其是在如今的大模型时代,如果不采用大批量训练,一个网络的训练时间几乎是难以接受的。但是,大批量训练很难,因为它会产生泛化差距(generalization gap), 直接训练会导致其准确率降低。此前的大批量工作往往针对于图像分类以及一些自然语言处理的任务,但密集预测任务(包括检测分割等),同样在视觉中处于举足轻重的位置,此前的方法并不能在密集预测任务上有很好的表现,甚至结果比基准线更差,这导致我们难以快速训练一个目标检测器。

 

为了解决这个问题,研究人员进行了大量的实验。最后发现,相较于传统的分类网络,利用密集预测网络一个很重要的特征:密集预测网络往往是由多个组件组成的,以 Faster R-CNN 为例:它由四个部分组成,骨干网络 (Backbone),特征金字塔网络(FPN),区域生成网络(RPN) 和检测头网络(head),我们可以发现一个很有效的指标:密集预测网络不同组件的梯度方差,在训练批量很小时(例如 32),几乎是相同的,但当训练批量很大时(例如 512),它们呈现出很大的区别,如下图所示:


那么,能不能直接把这些拉平呢?这直接引出了 AGVM 算法。以随机梯度下降算法为例,上角标 i 代表第 i 个网络模块(例如 FPN 等),上角标 1 代表骨干网络,代表学习率,锚定骨干网络,可以直接将不同网络组件的梯度 g 的方差



梯度的方差可以由以下式子估计:


方差的具体求解细节可以参考原文,本研究同样引入了滑动平均机制,防止网络训练发散。同时,研究证明了 AGVM 在非凸情况下的收敛性,讨论了动量以及衰减的处理方式,具体实现细节可以参考原文。

 

实验过程


本研究首先在目标检测、实例分割、全景分割和语义分割的各种密集预测网络上进行了测试,通过下表可以看到,当用标准批量大小训练时,AGVM 相较传统方法没有明显优势,但当在超大批量下训练时,AGVM 相较传统方法拥有压倒性的优势,下图第二列从左至右分别表示目标检测,实例分割,全景分割和语义分割的表现,AGVM 超越了有史以来的所有方法:


下表详细对比了 AGVM 和传统方法,体现出了本研究方法的优势:


同时,为了说明 AGVM 的优越性,本研究进行了以下三个超大规模的实验。研究人员把 Faster R-CNN 的 batch size 放到了 1536,这样利用 768 张 A100 可以在 4.2 分钟内完成训练。其次,借助 UniNet-G,本研究可以在利用 480 张 A100 的情况下,3.5 个小时让模型在 COCO 上达到 62.2mAP(不包括骨干网络预训练的时间),极大的减小了训练时间:


甚至,在 RetinaNet 上,本研究把批量大小扩展到 10K。这在目标检测领域是从未见的批量大小,在如此大的批量下,每一个 epoch 只有十几个迭代次数,AGVM 在如此大的批量下,仍然能展现出很强的稳定性,性能如下图所示:


结果分析


本研究探究了一个很重要的问题:以 RetinaNet 为例,如下图第一列所示,探究为什么会出现梯度方差不匹配这一现象。


本研究认为,这一现象来自于:网络不同模块间的有效批量大小 (effective batch size) 是不同的。例如,RetinaNet 的头网络的输入是由特征金字塔的五层网络输出的,特征金字塔的 top-down 和 bottom-up pathways,以及像素维度的损失函数计算会导致头网络和骨干网络的等效批量大小不同,这一原理导致了梯度方差不匹配的现象。


为了验证这一假设,本研究依次给每一层特征使用单独的头网络,移去特征金字塔网络,随机忽略掉 75% 的用于计算损失函数的像素,最终,本研究发现骨干网络和头网络的梯度方差曲线重合了,本研究也对 Faster R-CNN 做了类似的实验,如下图第二列所示,更多的讨论请参见原文。



相关文章
|
关系型数据库 MySQL 数据库
MySQL数据库加密和解密~认证登陆密码(mysql.user)和MySQL不区分大小写
MySQL数据库认证密码有两种方式: 1:MySQL 4.1版本之前是MySQL323加密 2:MySQL 4.1和之后的版本都是MySQLSHA1加密 还有函数:AES_ENCRYPT()加密函数和AES_DECRYPT()解密函数和MD5()加密。 MySQL数据库中自带old_password(str)和password(str)函数,前者是MySQL323加密,后者是MySQ
6237 0
|
安全 JavaScript 前端开发
区块链钱包系统开发解决方案/需求设计/功能逻辑/案例详细/源码步骤
The development of a blockchain wallet system involves multiple aspects, and the following is the detailed logic for developing a blockchain wallet system:
|
Java 数据库连接 mybatis
mybatis返回map类型数据空值字段不显示(三种解决方法)
mybatis返回map类型数据空值字段不显示(三种解决方法)
|
10月前
|
机器学习/深度学习 数据可视化 算法
RT-DETR改进目录一览 | 涉及卷积层、轻量化、注意力、损失函数、Backbone、SPPF、Neck、检测头等全方位改进
RT-DETR改进目录一览 | 涉及卷积层、轻量化、注意力、损失函数、Backbone、SPPF、Neck、检测头等全方位改进
704 5
|
Java
Error:java: 无效的目标发行版: 11解决方案
Error:java: 无效的目标发行版: 11解决方案
643 1
ConnectionResetError: [Errno 104] Connection reset by peer|4-16
ConnectionResetError: [Errno 104] Connection reset by peer|4-16
|
数据可视化 Ubuntu Linux
PyCharm连接远程服务器配置的全过程
相信很多人都遇见过这种情况:实验室成员使用同一台服务器,每个人拥有自己的独立账号,我们可以使用服务器更好的配置完成实验,毕竟自己哪有money拥有自己的3090呢。 通常服务器系统采用Linux,而我们平常使用频繁的是Windows系统,二者在操作方面存在很大的区别,比如我们实验室的服务器采用Ubuntu系统,创建远程交互任务时可以使用Terminal终端或者VNC桌面化操作,我觉得VNC很麻烦,所以采用Terminal进行实验,但是Terminal操作给我最不好的体验就是无法可视化中间实验结果,而且实验前后的数据上传和下载工作也让我头疼不已。
|
数据采集 数据挖掘 Serverless
利用Python和Pandas库优化数据清洗流程
在数据分析项目中,数据清洗是至关重要的一步。传统的数据清洗方法往往繁琐且易出错。本文将介绍如何利用Python编程语言中的Pandas库,通过其强大的数据处理能力,实现高效、自动化的数据清洗流程。我们将探讨Pandas库在数据清洗中的应用,包括缺失值处理、重复值识别、数据类型转换等,并通过一个实际案例展示如何利用Pandas优化数据清洗流程,提升数据质量。
|
安全
qt.qpa.xcb: could not connect to display 问题解决
【5月更文挑战第16天】qt.qpa.xcb: could not connect to display qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "" even though it was found. This application failed to start because no Qt platform plugin could be initialized. Reinstalling the application may fix this problem. 问题解决
7516 1