你知道数据模型需要多少训练数据吗?

简介:

毫无疑问机器学习是大数据分析不可或缺的一部分,在使用机器学习技术的时候工程师除了要选择合适的算法之外还需要选择合适的样本数据。那么工程师到底应该选择哪些样本数据、选择多少样本数据才最合适呢?来自于Google的软件工程师Malay Haldar最近发表了一篇题为《数据模型需要多少训练数据》的文章对此进行了介绍。

训练数据的质量和数量通常是决定一个模型性能的最关键因素。一旦训练数据准备好,其他的事情就顺理成章了。但是到底应该准备多少训练数据呢?答案是 这取决于要执行的任务,要满足的性能,所拥有的输入特征、训练数据中的噪音、提取特征中的噪音以及模型的复杂程度等因素。而找出这些变量之间相互关系的方 法就是在不同数据量的训练数据上训练模型并绘制学习曲线。但是这仅仅适合于已经有一定数量的训练数据的情况,如果是最开始的时候,或者说只有很少一点训练 数据的情况,那应该怎么办呢?

与死板地给出所谓精确的“正确”答案相比,更靠谱的方法是通过估算和具体的经验法则。例如本文将要介绍的实证方法:首先自动生成很多逻辑回归问题。 然后对生成的每一个问题,研究训练数据的数量与训练模型的性能之间的关系。最后通过观察这两者在这一系列问题上的关系总结出一个简单的规则。

生成一系列逻辑回归问题并研究不同数据量的训练数据所造成的影响的代码可以从GitHub上获取。相关代码是基于Tensorflow实现的,运行这些代码不需要任何特殊的软件或者硬件,用户可以在自己的笔记本上运行整个实验。代码运行之后生成的图表如下:

逻辑回归

其 中,X轴是训练样本的数量与模型参数数量的比率。Y轴是训练模型的得分(f-score)。不同颜色的曲线表示不同参数数量的模型。例如,红色曲线代表模 型有128个参数,曲线的轨迹表明了随着训练样本从128 x 1到 128 x 2并不断增长的过程中该模型的得分变化。

通过该图表,我们能够发现模型得分并不会随着参数规模的变化而变化。但是这是针对线性模型而言,对于一些隐藏的非线性模型并不适合。当然,更大的模 型需要更多的训练数据,但是对于一个给定的训练模型数量与模型参数数量比率其性能是一样的。该图表还显示,当训练样本的数量与模型参数数量的比率达到 10:1之后,模型得分基本稳定在0.85,该比率便可以作为良好性能模型的一种定义。根据该图表我们可以总结出10X规则,也就是说一个优秀的性能模型 需要训练数据的数量10倍于该模型中参数的数量。

10X规则将估计训练数据数量的问题转换成了需要知道模型参数数量的问题。对于逻辑回归这样的线性模型,参数的数量与输入特征的数量相等,因为模型会为每一个特征分派一个相关的参数。但是这样做可能会有一些问题:

特征可能是稀疏的,因而可能会无法直接计算出特征的数量。

由于正则化和特征选择技术,很多特征可能会被抛弃,因而与原始的特征数相比,真正输入到模型中的特征数会非常少。

避免这些问题的一种方法是:必须认识到估算特征的数量时并不是必须使用标记的数据,通过未标记的样本数据也能够实现目标。例如,对于一个给定的大文 本语料库,可以在标记数据进行训练之前通过生成单词频率的历史图表来理解特征空间,通过历史图表废弃长尾单词进而估计真正的特征数,然后应用10X规则来 估算模型需要的训练数据的数据量。

需要注意的是,神经网络构成的问题集与逻辑回归这样的线性模型并不相同。为了估算神经网络所需要的参数数量,你需要:

如果输入是稀疏的,那么需要计算嵌套层使用的参数的数量。参照word2vec的Tensorflow教程示例。

计算神经网络中边的数量

由于神经网络中参数之间的关系并不是线性的,所以本文基于逻辑回归所做的实证研究并不适合神经网络。但是在这种情况下,可以将10X规则作为训练数据所需数据量的下限。

尽管有上面的问题,根据Malay Haldar的经验,10X规则对于大部分问题还是适用的,包括浅神经网络。如果有疑问,可以在Tensorflow的代码中插入自己的模型和假设,然后运行代码进行验证研究。


本文作者:Malay Haldar

来源:51CTO

相关文章
|
机器学习/深度学习 存储 算法
时序数据特征工程浅析
内容摘要特征工程是指将原始数据标记处理为价值密度更高,更容易解释目标问题的工程化过程,在面向大量原始采集的数据集统计分析,尤其是对于高通量持续采集、且价值密度较低的时序数据更是如此。时序数据特征工程则是指利用有效方法,将原始时序数据转化为带有含义分类标签的序列数据片段或特征数值,例如,我们可以将指定时间窗口序列数据标识为特定异常关联数据,并保留平均、最大、最小值作为该序列的特征值。这样我们就可以围
3329 0
时序数据特征工程浅析
|
存储 数据采集 分布式计算
一篇文章搞懂数据仓库:四种常见数据模型(维度模型、范式模型等)
一篇文章搞懂数据仓库:四种常见数据模型(维度模型、范式模型等)
一篇文章搞懂数据仓库:四种常见数据模型(维度模型、范式模型等)
|
2月前
|
机器学习/深度学习 自然语言处理 C++
TSMamba:基于Mamba架构的高效时间序列预测基础模型
TSMamba通过其创新的架构设计和训练策略,成功解决了传统时间序列预测模型面临的多个关键问题。
185 4
TSMamba:基于Mamba架构的高效时间序列预测基础模型
|
5月前
|
计算机视觉
利用各类回归模型,对数据集进行建模
【8月更文挑战第8天】利用各类回归模型,对数据集进行建模。
52 4
|
8月前
|
机器学习/深度学习 数据采集 分布式计算
【机器学习】Spark ML 对数据进行规范化预处理 StandardScaler 与向量拆分
标准化Scaler是数据预处理技术,用于将特征值映射到均值0、方差1的标准正态分布,以消除不同尺度特征的影响,提升模型稳定性和精度。Spark ML中的StandardScaler实现此功能,通过`.setInputCol`、`.setOutputCol`等方法配置并应用到DataFrame数据。示例展示了如何在Spark中使用StandardScaler进行数据规范化,包括创建SparkSession,构建DataFrame,使用VectorAssembler和StandardScaler,以及将向量拆分为列。规范化有助于降低特征重要性,提高模型训练速度和计算效率。
164 6
|
7月前
|
机器学习/深度学习 分布式计算 监控
在大数据模型训练中,关键步骤包括数据收集与清洗、特征工程、数据划分;准备分布式计算资源
【6月更文挑战第28天】在大数据模型训练中,关键步骤包括数据收集与清洗、特征工程、数据划分;准备分布式计算资源,选择并配置模型如深度学习架构;通过初始化、训练、验证进行模型优化;监控性能并管理资源;最后保存模型并部署为服务。过程中要兼顾数据隐私、安全及法规遵守,利用先进技术提升效率。
124 0
|
7月前
|
机器学习/深度学习 数据采集 人工智能
特征工程对ML/DL至关重要,涉及数据清洗、转换和特征选择,以提升模型预测和泛化能力。
【6月更文挑战第28天】特征工程对ML/DL至关重要,涉及数据清洗、转换和特征选择,以提升模型预测和泛化能力。它改善数据质量,浓缩关键信息,优化性能,增强解释性。特征选择,如过滤法、RFE、嵌入式和包裹式方法,是关键步骤,常需迭代和结合业务知识。自动化工具如AutoML简化了这一过程。
65 0
|
数据采集 机器学习/深度学习 数据处理
类别数据处理:你必须知道的技巧与方法
类别数据处理:你必须知道的技巧与方法
175 0
|
机器学习/深度学习 数据采集 算法
UCI数据集详解及其数据处理(附148个数据集及处理代码)
UCI数据集详解及其数据处理(附148个数据集及处理代码)
3447 1
|
数据可视化 Python
使用PyMC进行时间序列分层建模
在统计建模领域,理解总体趋势的同时解释群体差异的一个强大方法是分层(或多层)建模。这种方法允许参数随组而变化,并捕获组内和组间的变化。在时间序列数据中,这些特定于组的参数可以表示不同组随时间的不同模式。
144 0

相关实验场景

更多