从深度学习到深度森林方法(Python)

简介: 本文第一节源于周志华教授《关于深度学习的一点思考》, 在此基础上对深度学习、深度森林做了原理解析并实践。

一、关于深度学习的一点思考



二、深度森林的介绍


目前深度神经网络(DNN)做得好的几乎都是涉及图像视频(CV)、自然语言处理(NLP)等的任务,都是典型的数值建模任务(在表格数据tabular data的表现也是稍弱的),而在其他涉及符号建模、离散建 模、混合建模的任务上,深度神经网络的性能并没有那么好。


深度森林(gcForest)是深度神经网络(DNN)之外的探索的一种深度模型,原文:it may open a door towards alternative to deep neural networks for many tasks。不同于深度神经网络由可微的神经元组成,深度森林的基础构件是不可微的决策树,其训练过程不基于 BP 算 法,甚至不依赖于梯度计算。它初步验证了关于深度学习奏效原因的猜想--即只要能做到逐层加工处理、内置特征变换、模型复杂度够,就能构建出有效的深度学习模型,并非必须使用神经网络。


深度森林主要的特点是:


  • 拥有比其他基于决策树的集成学习方法更好的性能


  • 拥有更少的超参数,并且无需大量的调参


  • 训练效率高,并且能够处理大规模的数据集


深度森林目前还处于探索阶段,评估模型(gcForest)的表现,在MNIST数据集准确率不错:



在CIFAR-10数据集上准确率欠佳:



三、深度森林原理


深度森林其实也就是ensemble of ensemble的模型,可以看作是对集成树(森林)模型,进一步stacking集成学习及优化(Complete Random Forest、shortcut-connection、Multi-Grained Scanning等),以解决深层容易过拟合的问题。


3.1 特征的处理


深度森林借鉴了CNN滑动卷积核的特征提取,通过多粒度扫描(Multi-Grained Scanning)方法,滑动窗口扫描原始特征,生成输入特征。



如上图,假设我们现在有一个400维(序列数据)的样本输入,现在设定采样窗口是100维的,那我们可以通过逐步的采样,最终获得301个子样本(默认采样步长是1,得到的子样本个数 = (400-100)/1 + 1)。



整个特征处理的过程就是:先输入一个完整的P维样本,然后通过一个长度为k的采样窗口进行滑动采样,得到S = (P - K)/1+1 个k维特征子样本向量,接着每个子样本都用于完全随机森林和普通随机森林的训练并在每个森林都获得一个长度为C(类别数)的概率向量,这样每个森林会产生长度为S*C的表征向量(即经过随机森林转换并拼接的概率向量),最后把每层的F个森林的结果拼接在一起得到本层输出。


3.2 深度森林的架构



整个模型的架构采用一个级联结构(如上图), 每一级都采用两种森林构建: random forests (black) and completely-random tree forests(blue),使用completely-random可以增加基模型的多样性,以减少过拟合风险,提高集成学习的效果。


完全随机森林(completely-random tree forests):由多棵树组成,每棵树包含所有的特征,并且随机选择一个特征作为分裂树的分裂节点。一直分裂到每个叶子节点只包含一个类别或者不多于是个样本结束。随机森林(random forests):同样由多棵树构成,每棵树通过随机选取Sqrt(特征总数)个特征 ,然后通过GINI分数来筛选分裂节点。


每一层的级联接收前一层处理的特征信息,并将处理结果(类向量)输出到下一层(如上图)。以三分类为例,输入特征为向量x,经过每个森林学习后(注:每个森林的学习的数据利用k折交叉验证得到,以减少过拟合风险),得到预测类分布,然后求平均,再与之前原始特征拼接(类似shortcut-connection),作为下一层的输入。


扩展完一层后,整个级联结构可在验证集上面测试性能,若没有显著提高,训练过程会终止,故而层数可以自动确定。这也是gcForest能够自动决定模型复杂度的原因。


四、深度森林预测


本节简单使用深度森林模型用于波士顿房价回归预测及癌细胞分类任务。 安装:pip install deep-forest


  • 波士顿房价回归预测,使用默认参数效果还不错:Testing MSE: 8.068


# 回归预测--波士顿房价
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from deepforest import CascadeForestRegressor
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
model = CascadeForestRegressor(random_state=1)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print("\nTesting MSE: {:.3f}".format(mse))


  • 癌细胞分类任务,准确率不错,Testing Accuracy: 95.105 %


# 分类预测--癌细胞分类数据集
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from deepforest import CascadeForestClassifier
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
model = CascadeForestClassifier(random_state=1)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred) * 100
print("\nTesting Accuracy: {:.3f} %".format(acc))
相关文章
|
24天前
|
JSON 数据可视化 API
Python 中调用 DeepSeek-R1 API的方法介绍,图文教程
本教程详细介绍了如何使用 Python 调用 DeepSeek 的 R1 大模型 API,适合编程新手。首先登录 DeepSeek 控制台获取 API Key,安装 Python 和 requests 库后,编写基础调用代码并运行。文末包含常见问题解答和更简单的可视化调用方法,建议收藏备用。 原文链接:[如何使用 Python 调用 DeepSeek-R1 API?](https://apifox.com/apiskills/how-to-call-the-deepseek-r1-api-using-python/)
|
3月前
|
机器学习/深度学习 数据可视化 TensorFlow
使用Python实现深度学习模型的分布式训练
使用Python实现深度学习模型的分布式训练
204 73
|
2月前
|
人工智能 自然语言处理 算法
随机的暴力美学蒙特卡洛方法 | python小知识
蒙特卡洛方法是一种基于随机采样的计算算法,广泛应用于物理学、金融、工程等领域。它通过重复随机采样来解决复杂问题,尤其适用于难以用解析方法求解的情况。该方法起源于二战期间的曼哈顿计划,由斯坦尼斯拉夫·乌拉姆等人提出。核心思想是通过大量随机样本来近似真实结果,如估算π值的经典示例。蒙特卡洛树搜索(MCTS)是其高级应用,常用于游戏AI和决策优化。Python中可通过简单代码实现蒙特卡洛方法,展示其在文本生成等领域的潜力。随着计算能力提升,蒙特卡洛方法的应用范围不断扩大,成为处理不确定性和复杂系统的重要工具。
77 21
|
2月前
|
数据挖掘 数据处理 开发者
Python3 自定义排序详解:方法与示例
Python的排序功能强大且灵活,主要通过`sorted()`函数和列表的`sort()`方法实现。两者均支持`key`参数自定义排序规则。本文详细介绍了基础排序、按字符串长度或元组元素排序、降序排序、多条件排序及使用`lambda`表达式和`functools.cmp_to_key`进行复杂排序。通过示例展示了如何对简单数据类型、字典、类对象及复杂数据结构(如列车信息)进行排序。掌握这些技巧可以显著提升数据处理能力,为编程提供更强大的支持。
42 10
|
12天前
|
SQL 关系型数据库 MySQL
Python中使用MySQL模糊查询的方法
本文介绍了两种使用Python进行MySQL模糊查询的方法:一是使用`pymysql`库,二是使用`mysql-connector-python`库。通过这两种方法,可以连接MySQL数据库并执行模糊查询。具体步骤包括安装库、配置数据库连接参数、编写SQL查询语句以及处理查询结果。文中详细展示了代码示例,并提供了注意事项,如替换数据库连接信息、正确使用通配符和关闭数据库连接等。确保在实际应用中注意SQL注入风险,使用参数化查询以保障安全性。
|
3月前
|
机器学习/深度学习 数据采集 供应链
使用Python实现智能食品消费需求分析的深度学习模型
使用Python实现智能食品消费需求分析的深度学习模型
103 21
|
3月前
|
机器学习/深度学习 数据采集 搜索推荐
使用Python实现智能食品消费偏好预测的深度学习模型
使用Python实现智能食品消费偏好预测的深度学习模型
133 23
|
3月前
|
机器学习/深度学习 数据采集 数据挖掘
使用Python实现智能食品消费习惯预测的深度学习模型
使用Python实现智能食品消费习惯预测的深度学习模型
177 19
|
3月前
|
安全
Python-打印99乘法表的两种方法
本文详细介绍了两种实现99乘法表的方法:使用`while`循环和`for`循环。每种方法都包括了步骤解析、代码演示及优缺点分析。文章旨在帮助编程初学者理解和掌握循环结构的应用,内容通俗易懂,适合编程新手阅读。博主表示欢迎读者反馈,共同进步。
|
3月前
|
机器学习/深度学习 数据采集 数据挖掘
使用Python实现智能食品消费模式预测的深度学习模型
使用Python实现智能食品消费模式预测的深度学习模型
93 2

热门文章

最新文章