迁移学习中如何利用权值调整数据分布?DATL、L2TL两大方法解析

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: 本文综述了两篇在迁移学习中利用权值调整数据分布的论文。通过这两个重要工作,读者可了解如何在迁移学习中进行微调的方法和理论。

深度神经网络的应用显著改善了各种数据挖掘和计算机视觉算法的性能,因此广泛应用于各类机器学习场景中。然而,深度神经网络方法依赖于大量的标记数据来训练深度学习模型,在实际应用中,获取足够的标记数据往往既昂贵又耗时。因此,一个自然的想法是利用现有数据集(即源域)中丰富的标记样本,辅助在要学习的数据集(即目标域)中的学习。解决这类跨领域学习问题的一种有效方法就是迁移学习:首先在一个大的标记源数据集(如 ImageNet)上训练模型,然后在目标数据集上进行模型调整更新,从而实现将已训练好的模型参数迁移到新的模型来帮助新模型训练。


基于深度神经网络的迁移学习主要有三种方式:一是迁移学习(Transfer Learning),重新训练全连接层,其他预训练模型的卷积层不变;二是特征向量提取(Feature Vector Extraction),利用预训练模型的卷积层提取源和目标数据集的特征向量,之后训练目标域中的全连接网络;三是微调(Fine-tune),重新学习分类层的参数,而其余网络层参数则沿用预训练模型的初始化值。


研究人员发现,仅靠改进迁移学习的方式(如上述三种迁移学习方式)并不能进一步降低目标域中模型的损失值,而选择改进用作模型预训练的源数据集的丰富程度则是一种有效的方法。源数据集的丰富程度并不仅由数据集中数据量的大小决定,而同时取决于用于预训练的数据集是否能够有效捕获到与目标域中数据集相似的差异性特征(因素)。前期的方法主要是通过不同的度量方法找到源数据集与目标数据集中的相似样本数据,例如 [1] 使用滤波器组响应中的特征来选择源数据集中的最近邻样本,与使用整个源数据集相比,该方法具备更好的性能。[2] 利用土方运距(Earth Mover』s Distance,EMD)对源数据集和目标数据集之间的区域相似性进行量化计算,之后利用一个简单的贪婪子集生成选择准则提高目标测试集的性能。然而上述方法只是找到相似的样本数据,无法有效捕获目标数据集中的变化判别因素,因此迁移学习的效果改进有限。这种微调相当于对迁移学习的前两种步骤的改进,进一步提升了迁移学习的性能,因此本文探讨的是改善微调方式的迁移学习。


Ngiam et al. 提出了一种利用权值捕获源域和目标域中相似信息从而有效调整数据分布的方法,即基于目标数据集的重要权值域自适应迁移学习方法(Domain Adaptive Transfer Learning,DATL)[1]。DATL 利用概率形态识别源数据集中能够有效捕获目标数据集中变化判别因素的样本数据,使用 JFT 和 ImageNet 数据集作为源训练数据,并考虑一系列用于微调的目标数据集。在微调过程中,对网络中的分类层进行随机初始化训练。在这项工作的基础上,Zhu et al. 提出了共享权值的概念,即对源和目标任务模型之间共享权值联合优化的学习框架(Learning to Transfer Learn,L2TL)[2],其中关于共享权值的计算是利用基于目标数据集的性能度量矩阵的强化学习模块(RL)实现的,从而保证自适应输出每个源数据集中类别的权值。L2TL 基于目标数据集中的测试性能自适应的推断域相似度。本文对 DATL 和 L2TL 进行详细的分析,目的是探讨在迁移学习中利用权值调整数据分布的有效性,以及计算权值的不同方式对迁移学习效果、计算成本等的影响。


1、Domain Adaptive Transfer Learning with Specialist Models


原文地址:https://arxiv.org/pdf/1811.07056.pdf


方法分析


DATL 使用 JFT 和 ImageNet 数据集作为源预训练数据,不在源数据集和目标数据集之间执行任何标签对齐处理。而是利用数据集之间的标签产生的权值进行调整。在微调过程中,对神经网络中的分类层进行随机初始化训练。首先考虑一个简化的设置,即源数据集和目标数据集位于相同的像素 x 和标签 y 值集上。预训练阶段,在源域中优化参数θ以最小化损失函数:


(1)


微信图片_20211202065008.jpg


其中 Ds 表示源数据集,L(f_θ(x),y) 为模型 f_θ(y) 的预测与标签真值 y 之间的交叉熵损失函数。源数据集 Ds 中的数据分布与目标数据集 Dt 中的分布可能不同,通过加大与目标数据集最相关的样本的权值来解决这种问题。目标数据集 Dt 中的损失函数为:


(2)

微信图片_20211202065101.jpg其中 Ps、Pt 分别表示源和目标数据集的概率分布。结合以上两个公式,重新计算(2)包含源数据集 Ds 的损失函数如下:


(3)


微信图片_20211202065124.jpg


接下来,假设 Ps(x|y) 约等于 Pt(x|y),即在源数据集中给定特定标签的样本分布与目标数据集的近似分布是相同的,(3)可简化为:


微信图片_20211202065205.jpg


其中 Pt(y)/Ps(y) 为我们需要的权值。


为了使 DATL 在实践中适用,需要对简化设置(即源数据集和目标数据集共享相同的标签空间)进行放松假设,放松假设的处理过程具体为:「在真实的应用场景中,源数据集和目标数据集一般具有不同的标签集,解决方案是 Pt(y) 和 Ps(y) 的估计都在源域中进行,而不再基于目标域估计 Pt(y)。通过将标签出现的次数除以源数据集的样本总数计算分母 Ps(y)。为了估计 Pt(y),则使用一个分类器来计算来自源数据集的标签在来自目标数据集的样本上的概率。」


完整的 DATL 方法示例见图 1。为了计算重要性权值 Pt(y)/Ps(y),首先使用在整个 JFT 数据集上预训练的图像模型来评估来自目标数据集的图像。对于每一幅图像,能够得到其对 JFT 中 18291 个类的预测。对这些预测进行平均化处理后得到 Pt(y)。通过将标签在源预训练数据集中出现的次数除以源预训练数据集中的样本总数,直接从源预训练数据集中估计 Ps(y)。因此,权值 Pt(y)/Ps(y) 表示源预训练数据集中给定标签的重要程度。使用这些重要性权值在整个 JFT 数据集上训练生成预训练模型,然后在目标数据集上进行微调。


微信图片_20211202065228.jpg图 1. DATL 方法完整过程


实验分析


本文实验中通过使用重要性权值从源数据集(JFT 和 ImageNet)中采样样本来创建预训练数据集。预训练阶段使用 Inception v3 和 AmoebaNet-B 神经网络模型,微调阶段使用随机初始化的分类层来代替预训练的分类层。利用 SGD 对模型进行 20000 步的训练,每个小批量包含 256 个样本。使用保持验证集(hold-out validation set)计算权值正则化和学习速度参数。


微信图片_20211202065302.jpg表 1. 使用 Inception v3 的迁移学习结果


表 1 给出使用 Inception v3 的迁移学习结果,每一行对应一个预训练方法,其中 Adaptive Transfer 指的是本文提出的方法。每列对应一个目标数据集。表 1 中结果是除 Oxford-IIIT Pets 外的所有数据集的最高准确度,我们给出了每个类的平均准确度。所有结果均执行 5 次微调处理。由表 1 结果可知,当源数据集与目标数据集完全匹配时,迁移学习效果最优;当源域和目标域不匹配时出现了负迁移。值得注意的是,在预训练阶段使用更多的源数据反而会影响迁移学习的效果。在所有类别情况下,在整个 JFT 数据集上预训练的模型效果都差于在某些具体子集上预训练的模型效果。此外,使用本文提出的 DATL 方法甚至比手动选择标签效果更好。


微信图片_20211202065329.jpg表 2. 使用 AmoebaNet-B 的迁移学习结果


表 2 给出了使用 AmoebaNet-B 的迁移学习结果,实验目的是验证较大模型是否能够更好的捕获更多的变化因素。AmoebaNet-B 上的实验参数超过 5.5 亿。另外,表 2 中的实验结果(AmoebaNet-B)优于表 1 的结果(Inception v3)。说明使用较大的模型能够缩小一般子集和特定子集之间的性能差距。


微信图片_20211202065359.jpg图 2. 使用 ImageNet 作为源预训练数据集时,每个目标数据集的重要性权值分布


最后,图 2 给出了使用 ImageNet 作为源预训练数据集时,每个目标数据集的重要性权值分布。由图 2 可知目标数据集之间的分布差异很大。FGVC Aircraft 只选择了一些粗粒度的标签,而 Oxford Pets 则选择了更广泛的细粒度标签,这反映了 ImageNet 数据集中固有的偏差。


总结


本文提出的 DATL 方法能够有效识别源预训练数据集中包含类别判别信息的数据样本,当未能有效捕获判别信息时迁移学习的效果就会受到影响。此外,本文实验还证明当使用较大的神经网络模型时,在类别子集中预训练的迁移学习效果更好。也就是说,如果是在完整的源数据集中完成预训练,则训练过程还需额外处理细粒度类别间的区别。


相关文章
|
16天前
|
人工智能
歌词结构的巧妙安排:写歌词的方法与技巧解析,妙笔生词AI智能写歌词软件
歌词创作是一门艺术,关键在于巧妙的结构安排。开头需迅速吸引听众,主体部分要坚实且富有逻辑,结尾则应留下深刻印象。《妙笔生词智能写歌词软件》提供多种 AI 功能,帮助创作者找到灵感,优化歌词结构,写出打动人心的作品。
|
22天前
|
存储 算法 Java
解析HashSet的工作原理,揭示Set如何利用哈希算法和equals()方法确保元素唯一性,并通过示例代码展示了其“无重复”特性的具体应用
在Java中,Set接口以其独特的“无重复”特性脱颖而出。本文通过解析HashSet的工作原理,揭示Set如何利用哈希算法和equals()方法确保元素唯一性,并通过示例代码展示了其“无重复”特性的具体应用。
38 3
|
17天前
|
人工智能
写歌词的技巧和方法全解析:开启你的音乐创作之旅,妙笔生词智能写歌词软件
怀揣音乐梦想,渴望用歌词抒发情感?掌握关键技巧,你也能踏上创作之旅。灵感来自生活点滴,主题明确,语言简洁,韵律和谐。借助“妙笔生词智能写歌词软件”,AI辅助创作,轻松写出动人歌词,实现音乐梦想。
|
1天前
|
JSON PHP 数据格式
PHP解析配置文件的常用方法
INI文件是最常见的配置文件格式之一。
|
8天前
|
机器学习/深度学习 人工智能 安全
TPAMI:安全强化学习方法、理论与应用综述,慕工大、同济、伯克利等深度解析
【10月更文挑战第27天】强化学习(RL)在实际应用中展现出巨大潜力,但其安全性问题日益凸显。为此,安全强化学习(SRL)应运而生。近日,来自慕尼黑工业大学、同济大学和加州大学伯克利分校的研究人员在《IEEE模式分析与机器智能汇刊》上发表了一篇综述论文,系统介绍了SRL的方法、理论和应用。SRL主要面临安全性定义模糊、探索与利用平衡以及鲁棒性与可靠性等挑战。研究人员提出了基于约束、基于风险和基于监督学习等多种方法来应对这些挑战。
21 2
|
16天前
|
安全 Java
Java多线程通信新解:本文通过生产者-消费者模型案例,深入解析wait()、notify()、notifyAll()方法的实用技巧
【10月更文挑战第20天】Java多线程通信新解:本文通过生产者-消费者模型案例,深入解析wait()、notify()、notifyAll()方法的实用技巧,包括避免在循环外调用wait()、优先使用notifyAll()、确保线程安全及处理InterruptedException等,帮助读者更好地掌握这些方法的应用。
13 1
|
23天前
|
存储 JavaScript 前端开发
Vue3权限控制全攻略:路由与组件层面的用户角色与权限管理方法深度解析
Vue3权限控制全攻略:路由与组件层面的用户角色与权限管理方法深度解析
93 2
|
23天前
|
SQL 监控 数据库
SQL语句是否都需要解析及其相关技巧和方法
在数据库管理中,SQL(结构化查询语言)语句的使用无处不在,它们负责数据的查询、插入、更新和删除等操作
|
23天前
|
SQL 数据可视化 BI
SQL语句及查询结果解析:技巧与方法
在数据库管理和数据分析中,SQL语句扮演着至关重要的角色
|
24天前
|
JavaScript
深入解析:JS与Vue中事件委托(事件代理)的高效实现方法
深入解析:JS与Vue中事件委托(事件代理)的高效实现方法
27 0

推荐镜像

更多
下一篇
无影云桌面