基于PyTorch,集合17种方法,南京大学等提出小样本算法库LibFewShot

简介: 近日,南京大学推理与学习研究组(Reasoning and Learning Research Group, R&L Group)联合澳大利亚伍伦贡大学、美国罗彻斯特大学开源了一个小样本学习算法库 LibFewShot。该库包含了 17 个 2017 年到 2020 年具有代表性的小样本学习算法,为小样本学习领域中算法对比采用统一框架、统一设置、实现公平对比等提供便利。关于 LibFewShot 的文章已经发布在 arXiv。

近日,南京大学推理与学习研究组(Reasoning and Learning Research Group, R&L Group)联合澳大利亚伍伦贡大学、美国罗彻斯特大学开源了一个小样本学习算法库 LibFewShot。该库包含了 17 个 2017 年到 2020 年具有代表性的小样本学习算法,为小样本学习领域中算法对比采用统一框架、统一设置、实现公平对比等提供便利。关于 LibFewShot 的文章已经发布在 arXiv。


微信图片_20211206125805.jpg




近年来,小样本学习,特别是小样本图像分类问题引起了学界越来越多的关注。最近的一些研究表明,许多通用的技术或者技巧,如数据增强、预训练、知识蒸馏和自监督可能会极大地提高小样本学习方法的性能。此外,不同的算法可能由于使用不同的深度学习框架,不同的训练方式,不同的嵌入网络,甚至不同的输入图像大小,使得对这些算法进行公平对比变得困难,而且初学者也往往难以复现相关算法。


为了解决这些问题,我们提出了一个统一的小样本学习算法库 LibFewShot,该库基于 PyTorch 深度学习框架,统一重新实现了 17 种小样本学习方法。此外,基于 LibFewShot 提供了基于四层卷积、ResNet12、ResNet18 这三种 backbone 在 miniImageNet、tieredImageNet 两个数据集上的评估结果以及在 CUB-Birds、StanfordDog、StanfordCar 三个数据集上的跨域结果,用以综合评估不同算法的效果。


另外,鉴于一些最新工作开始重新思考 meta 或 episodic 训练机制的必要性,我们也在这个方向上,基于 LibFewshot 框架进行了一些探索和研究,并从实验结果中发现 meta 或 episodic 训练机制还是有效的,特别是在与预训练相结合时,这样的训练机制仍然能够显著提升模型的表现。我们希望 LibFewShot 不仅能够降低小样本学习的使用门槛,还能够消除一些常用深度学习技巧的影响,以促进小样本学习领域的发展。


小样本总览


小样本任务通常包含两部分数据,一部分是用来学习的有标签的支撑集(support set),另一部分是待分类的无标签的查询集(query set)。为了获得对每个任务快速学习的能力,通常还有一个大的辅助集(auxiliary set),通常支撑集和查询集的实际类别是一致的,而辅助集的类别和它们是不相交的。小样本学习中 「小」 的概念来自于支撑集,支撑集有 C 类图像,每类图像有 K 张,称为 C-way K-shot 小样本问题,C 通常取 5 或 10,K 通常取 1 或者 5。小样本学习任务的重点是如何通过在辅助集上的进行学习,使得在面对新的任务时,仅仅通过支撑集的少量样本,就能够完成对查询集的识别和分类。


根据在辅助集上以及支撑集上训练策略的不同,将小样本学习分为三类,分别是基于微调的方法、基于元学习的方法、基于度量的方法,图 1 中给出了三类方法的代表性框架图。


  • 基于微调的方法:基于微调的方法和迁移学习有着相似的过程,一般可分为使用辅助集的预训练阶段和使用支撑集的微调阶段。代表方法有 Baseline[1],Baseline++[1],RFS[2],SKD[3]等;


  • 基于元学习的方法:基于元学习的方法在训练阶段采用元训练的方式来在辅助集上进行训练,通常采用二阶段的优化,一个阶段是支撑集更新基学习器,另一阶段用查询集更新元学习器、适应新的任务。代表方法有 MAML[4],R2D2[5]等;


  • 基于度量的方法:基于度量的方法通常在辅助集上的训练采用的是 episodic training 的方式,即在这个阶段从辅助集中采样相似的小样本学习任务,使用大量相似任务来训练网络,使得网络能够学习到快速适应新任务的能力。代表方法有 ProtoNet[6],RelationNet[7],DN4[8]等。


微信图片_20211206130403.jpg

图 1. 小样本学习方法分类,(a) 基于微调的方法;(b) 基于元学习的方法;(c) 基于度量的方法.


复现代码


基于以上三个种类的分类,LibFewShot 实现了 17 个具有代表性的算法,从上到下依次是基于微调的方法、基于元学习的方法和基于度量的方法,复现结果如下:


image.jpeg


此外,由于各个方法所使用的技巧、网络结构等存在一定的差别。为了公平比较不同方法在各个数据集上的表现,LibFewShot 统一了一些变量,在 miniImageNet 和 tieredImageNet 上做了如下实验:


微信图片_20211206131610.jpg


从表中可以看出,RFS 和 SKD 这样的基于微调的方法,表现要明显优于其他的方法,但是也用了更多的技巧,例如,预训练(pre-train)、自监督(self-supervised,SS)和知识蒸馏(knowledge distillation,KD)都能够明显提高模型的表现。当基于元学习的方法和基于度量学习的方法在使用了预训练等技巧后,如 MTL 和 CAN,表现也接近于 RFS 和 SKD。由此可见,预训练加上 episodic training 的微调是一种很有潜力的训练方法。


另外,论文中也比较了这些方法在 ResNet12 网络结构下的跨域表现,结果如下表所示:


image.jpeg


从表中可以看出,从 miniImageNet 和 tieredImageNet 数据集泛化到 StanfordDog 和 CUB-Birds 数据集是比较简单的,因为它们都是相似的自然图像数据集。而当迁移到 StanfordCar 数据集上,所有方法的表现都会急剧下降,这表明,当前 SOTA 的方法并不能很好处理域偏移较大的跨域情况。


讨论


通过对比以上实验结果,我们提出了两个问题:(1)在训练阶段,使用支撑集进行 episodic training 是否真的不重要,(2)在测试阶段,使用测试任务的支撑集进行微调是否真的对小样本测试很有效。为了回答这两个问题,我们选取了具有代表性的 RFS、ProtoNet 以及 MTL 做了对比实验。对于问题 1,我们选取了 ProtoNet 和 MTL 两个方法,这两个方法的结果从上述表格中结果来看都是不如 RFS 的。但是当使用了 RFS 的预训练模型,并在此基础上使用 episodic training 微调预训练网络之后,从图 2 柱状图可以看出,两种方法的结果都要比 RFS 高上许多。因此,我们认为训练阶段的 episodic training 还是有必要的,特别是在结合预训练的情况下,能够更进一步提升模型的性能。对于问题 2,我们发现如果使用和 RFS 相同预训练的网络,在测试时直接采用使用欧氏距离的 ProtoNet,RFS 的效果的确是更好一些的。但是当 ProtoNet 使用了 L2 归一化后,即改成使用余弦距离后,结果反而要比 RFS 高,因此得出一个有意思的结论:在测试阶段进行微调可能并没有那么重要,相反 L2 归一化可能起了更加重要的作用。


微信图片_20211206131804.jpg

图 2. ProtoNet-cosine 和 MTL-protonet 在预训练模型基础上进行微调的结果


LibFewShot 是一个集成了多个代表性小样本学习方法的统一框架,为小样本学习领域的方法进行公平的实验可以带来巨大的便利。同时,我们也对小样本学习中预训练和 episodic training 的作用进行了深入的思考,肯定了预训练的价值,也证明了 episodic training 的必要性,同时也强调了 L2 归一化在小样本学习中的作用。


参考文献:[1] Chen, Wei-Yu, Yen-Cheng Liu, Zsolt Kira, Yu-Chiang Frank Wang, and Jia-Bin Huang. "A Closer Look at Few-shot Classification." In International Conference on Learning Representations. 2018.[2] Tian, Yonglong, Yue Wang, Dilip Krishnan, Joshua B. Tenenbaum, and Phillip Isola. "Rethinking few-shot image classification: a good embedding is all you need?" In European Conference on Computer Vision 2020, Part XIV 16, pp. 266-282. 2020.[3] Rajasegaran, Jathushan, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Mubarak Shah. "Self-supervised knowledge distillation for few-shot learning." arXiv preprint arXiv:2006.09785. 2020.[4] Finn, Chelsea, Pieter Abbeel, and Sergey Levine. "Model-agnostic meta-learning for fast adaptation of deep networks." In International Conference on Machine Learning, pp. 1126-1135. PMLR, 2017.[5] Bertinetto, Luca, Joao F. Henriques, Philip Torr, and Andrea Vedaldi. "Meta-learning with differentiable closed-form solvers." In International Conference on Learning Representations. 2018.[6] Snell, Jake, Kevin Swersky, and Richard Zemel. "Prototypical Networks for Few-shot Learning." Advances in Neural Information Processing Systems 30 (2017): 4077-4087. 2017.[7] Sung, Flood, Yongxin Yang, Li Zhang, Tao Xiang, Philip HS Torr, and Timothy M. Hospedales. "Learning to compare: Relation network for few-shot learning." In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1199-1208. 2018.[8] Li, Wenbin, Lei Wang, Jinglin Xu, Jing Huo, Yang Gao, and Jiebo Luo. "Revisiting local descriptor based image-to-class measure for few-shot learning." In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7260-7268. 2019.




相关文章
|
8天前
|
算法 安全 数据安全/隐私保护
Crypto++库支持多种加密算法
【10月更文挑战第29天】Crypto++库支持多种加密算法
32 4
|
23天前
|
存储 算法 Java
解析HashSet的工作原理,揭示Set如何利用哈希算法和equals()方法确保元素唯一性,并通过示例代码展示了其“无重复”特性的具体应用
在Java中,Set接口以其独特的“无重复”特性脱颖而出。本文通过解析HashSet的工作原理,揭示Set如何利用哈希算法和equals()方法确保元素唯一性,并通过示例代码展示了其“无重复”特性的具体应用。
38 3
|
27天前
|
算法 索引
HashMap扩容时的rehash方法中(e.hash & oldCap) == 0算法推导
HashMap在扩容时,会创建一个新数组,并将旧数组中的数据迁移过去。通过(e.hash & oldCap)是否等于0,数据被巧妙地分为两类:一类保持原有索引位置,另一类索引位置增加旧数组长度。此过程确保了数据均匀分布,提高了查询效率。
37 2
|
30天前
|
机器学习/深度学习 算法 PyTorch
Pytorch-RMSprop算法解析
关注B站【肆十二】,观看更多实战教学视频。本期介绍深度学习中的RMSprop优化算法,通过调整每个参数的学习率来优化模型训练。示例代码使用PyTorch实现,详细解析了RMSprop的参数及其作用。适合初学者了解和实践。
35 1
|
1月前
|
搜索推荐 Shell
解析排序算法:十大排序方法的工作原理与性能比较
解析排序算法:十大排序方法的工作原理与性能比较
47 9
|
30天前
|
存储 算法 Java
数据结构与算法学习八:前缀(波兰)表达式、中缀表达式、后缀(逆波兰)表达式的学习,中缀转后缀的两个方法,逆波兰计算器的实现
前缀(波兰)表达式、中缀表达式和后缀(逆波兰)表达式的基本概念、计算机求值方法,以及如何将中缀表达式转换为后缀表达式,并提供了相应的Java代码实现和测试结果。
31 0
数据结构与算法学习八:前缀(波兰)表达式、中缀表达式、后缀(逆波兰)表达式的学习,中缀转后缀的两个方法,逆波兰计算器的实现
|
1月前
|
存储 并行计算 PyTorch
探索PyTorch:模型的定义和保存方法
探索PyTorch:模型的定义和保存方法
|
2月前
|
存储 算法 安全
超级好用的C++实用库之sha256算法
超级好用的C++实用库之sha256算法
81 1
|
30天前
|
机器学习/深度学习 算法 PyTorch
Pytorch-SGD算法解析
SGD(随机梯度下降)是机器学习中常用的优化算法,特别适用于大数据集和在线学习。与批量梯度下降不同,SGD每次仅使用一个样本来更新模型参数,提高了训练效率。本文介绍了SGD的基本步骤、Python实现及PyTorch中的应用示例。
30 0
|
30天前
|
机器学习/深度学习 传感器 算法
Pytorch-Adam算法解析
肆十二在B站分享深度学习实战教程,本期讲解Adam优化算法。Adam结合了AdaGrad和RMSProp的优点,通过一阶和二阶矩估计,实现自适应学习率,适用于大规模数据和非稳态目标。PyTorch中使用`torch.optim.Adam`轻松配置优化器。
38 0
下一篇
无影云桌面