谷歌发布最新元学习框架「DVRL」,用强化学习量化模型中每个数据点的价值

简介: Google AI研究院最近的研究表明,并不是所有数据样本对于训练都同样有用,特别是对于深度神经网络(DNN)来说。昨日,他们发表了一篇博客,详细叙述了用强化学习评估训练数据的影响。

微信图片_20220109174728.png


你是否还在使用大规模数据集进行无脑训练呢?

实际上,如果数据集包含低质量或标签不正确的数据,通常可以通过删除大量的训练样本来提高性能。

 

此外,如果训练集与测试集不匹配(例如,由于训练和测试位置或时间的差异) ,人们还可以通过将训练集中的样本限制为与测试场景最相关的样本,从而获得更高的性能。

 

由于这些场景的普遍存在,准确量化训练样本的值对于提高真实数据集上的模型性能具有很大的潜力。


        微信图片_20220109174730.png    

 

除了提高模型性能之外,为单个数据分配质量值(quality value)也可以启用新的用例,也可以用来提出更好的数据收集方法。

 

例如,什么类型的附加数据最有利,并可用于更有效地构建大规模的训练数据集,或者使用标签作为关键字进行网络搜索,过滤掉不太有价值的数据。

 

量化数据的价值

 

对于给定的机器学习模型,并不是所有的数据都是相等的。一些数据与手头的任务有更大的相关性,或者相比其他数据有更丰富的信息内容。

 

那么,到底该如何评估单一数据的价值呢?在完整数据集的粒度上,人们可以简单地在整个数据集上训练一个模型,并将其在测试集上的性能作为数据的价值。

 

然而估计单一数据的价值要困难得多,特别是对于依赖于大规模数据集的复杂模型,因为在计算复杂度上来说,不可能对一个模型的所有可能的子集进行重新训练和评估。

 

为了解决这个问题,研究人员探索了基于排列的方法(例如:influence functions)和基于博弈论的方法(例如:data Shapley)。

 

微信图片_20220109174732.jpg


然而,即使是当前最好的方法也远不能适用于大型数据集和复杂模型,而且它们的数据评估性能也是有限的。

 

同时,基于元学习(meta learning)的自适应权重分配方法已经被开发出来,用来使用元目标(meta-objective)估计权重值。


但是他们并没有优先考虑从高价值的数据样本中学习,而通常是基于梯度下降法学习或者其他启发式方法得到数据价值的映射。这些方法改变了传统的预测模型的动态训练,会导致与单个数据点的价值无关的性能变化。

 

使用强化学习评估数据(DVRL)

 

为了推断数据值,我们提出了一种数据值估计器(DVE) ,该估计器用来估计数据值,并选择最有价值的样本来训练预测器模型。

 

这种操作基本上是不可微的,因此不能使用传统的基于梯度下降的方法。

 

相反,Google的研究员们建议使用强化学习(RL) ,这样 DVE 的监督是基于一个奖励Reward,而这个Reward能用来量化预测器在一个很小但干净的验证集上面的性能。

 

DVRL:Data Valuation Using Reinforcement Learning


在给定状态和输入样本的情况下,Reward指导Policy进行最优化选择,向着最优的数据价值方向进行。


        微信图片_20220109174734.gif


Google AI 研究院以预测模型学习和评估框架为环境,提出了一种新的基于实例推理的机器学习应用方案。      


微信图片_20220109174735.gif


图:使用DVRL进行训练。在用准确的Reward训练DVE时,最有价值的样本(用绿点表示)被使用得越来越多,而最无价值的样本(红点)被使用得越来越少。

 

实验结果

 

结果评估了 DVRL 在不同类型数据集和用例上的数据价值估计的质量。

 

1.去除高/低值样本后的模型性能:

 

从训练集中剔除低值样本可以提高预测器模型的性能,特别是在训练集中含有损坏样本的情况下。

 

另一方面,移除高值的样本,特别是当数据集很小时,会显著降低性能。

 

总体而言,剔除高/低值样本后的表现是数据评估质量的一个强有力的指标

    微信图片_20220109174737.png      

2.带有噪声标签的鲁棒学习:

 

Google AI的研究人员考虑使 DVRL 在带有噪声标签时可以在端到端的方式中学习,而不必删除低价值的样本。

 

理想情况下,噪声样本应该得到低数据值,因为 DVRL 会收敛的同时将返回一个高性能模型。

        微信图片_20220109174739.png


图:数据集的标签上有40% 的均匀随机噪声,DVRL 优于其他流行的基于元学习的方法


结果显示,在最小化噪声标签影响的情况下,DVRL取得了SOTA的结果。这也表明了DVRL可以应用到复杂模型和大规模数据集。

 

3.领域适应(Domain adaptation):

 

Google考虑的场景是,训练集来自与验证和测试集完全不同的分布。通过从训练数据集中选择最适合验证数据集分布的样本,数据估值预计将对此任务有所帮助。

    微信图片_20220109174740.png      

 DVRL 通过联合优化数据估值器和相应的预测器模型,显著提高了领域的适应性。

 

结论


Google AI研究院这次提出了一种新的元学习数据评估框架,该框架决定了每个训练样本用在预测模型的训练过程的可能性。

 

与以往的研究不同的是,该方法将数据评估融入到预测器模型的训练过程中,使得预测器和DVE能够相互提高。

 

通过使用一个经过 RL 训练的 DNN 对这个数据值估计任务进行建模,并从一个代表目标任务绩效的小验证集中获得奖励。

 

DVRL 以高效的计算方法提供了高质量的排序后的训练数据,有利于领域自适应、错误样本发现和鲁棒学习,同时还发现了 DVRL 在不同类型的任务和数据集上显著优于其他方法。

 

 

参考链接:

https://ai.googleblog.com/2020/10/estimating-impact-of-training-data-with.html

相关文章
|
开发工具 iOS开发 MacOS
解决VScode文件无法编辑(删除键 换行键失去作用)
解决VScode文件无法编辑(删除键 换行键失去作用)
3882 0
|
3月前
|
数据采集 人工智能 监控
大模型微调数据质量评估指南:如何为你的AI挑选“好食材”
本文系统介绍大模型微调数据质量的科学评估框架,提出“复杂性、可用性、多样性”三大核心维度,并结合推理损失逆向验证,提供可落地的五步评估法与实操工具(如LLaMA-Factory Online),助力团队以更少高质量数据获得更优模型效果。
|
9月前
|
域名解析 JSON API
【干货满满】如何处理requests库调用API接口时的异常情况
在调用 API 时,网络波动、服务器错误、参数异常等情况难以避免。本文提供一套系统化的异常处理方案,涵盖 requests 库常见异常类型、处理策略、实战代码与最佳实践,通过分类处理、重试机制与兜底策略,提升接口调用的稳定性与可靠性。
|
资源调度 JavaScript 前端开发
前端开发必备!Node.js 18.x LTS保姆级安装教程(附国内镜像源配置)
本文详细介绍了Node.js的安装与配置流程,涵盖环境准备、版本选择(推荐LTS版v18.x)、安装步骤(路径设置、组件选择)、环境验证(命令测试、镜像加速)及常见问题解决方法。同时推荐开发工具链,如VS Code、Yarn等,并提供常用全局包安装指南,帮助开发者快速搭建高效稳定的JavaScript开发环境。内容基于官方正版软件,确保合规性与安全性。
14512 23
|
存储 Java 关系型数据库
ssm150旅游网站的设计与实现+jsp(文档+源码)_kaic
本旅游网站基于现代经济快节奏发展和信息化技术的升级,采用SSM框架、Java语言及Mysql数据库开发。它实现了景点、新闻、酒店、飞机票和火车票管理等功能,帮助管理者高效处理大量数据信息,提升工作效率。系统界面简洁美观,功能布局合理,同时提供了数据安全解决方案,确保信息的安全性和可靠性。该网站不仅提高了事务处理效率,还实现了数据的整体化、规范化与自动化管理。关键词:旅游网站;SSM框架;Mysql;自动化。
|
数据采集 机器学习/深度学习 人工智能
[大语言模型-论文精读] 利用多样性进行大型语言模型预训练中重要数据的选择
[大语言模型-论文精读] 利用多样性进行大型语言模型预训练中重要数据的选择
|
JavaScript 前端开发 API
JavaScript循环遍历常用的7种方法以及常用的数组 API
JavaScript循环遍历常用的7种方法以及常用的数组 API
506 0
|
存储 算法 数据格式
一篇文章讲明白Mipmap与纹理过滤
一篇文章讲明白Mipmap与纹理过滤
698 1
|
JavaScript 调度
Three.js开发秘籍:FlyControls的拖拽视角问题解决方案
Three.js开发秘籍:FlyControls的拖拽视角问题解决方案
306 0
|
机器学习/深度学习 并行计算 PyTorch
【已解决】RuntimeError: CUDA error: device-side assert triggeredCUDA kernel errors might be asynchronous
【已解决】RuntimeError: CUDA error: device-side assert triggeredCUDA kernel errors might be asynchronous
11692 2

热门文章

最新文章