TensorFlow 2 Keras实现线性回归

简介: TensorFlow 2 Keras实现线性回归

介绍


线性回归是入门机器学习必学的算法,其也是最基础的算法之一。


接下来,我们以线性回归为例,使用 TensorFlow 2 提供的 API 和 Eager Execution 机制对其进行实现。


线性回归是一种较为简单,但十分重要的机器学习方法,它也是神经网络的基础。


如下所示,线性回归要解决的问题就是如何找到最理想的直线去拟合散点样本。


13.png


对于一个线性回归问题,一般来讲有 2 种解决方法,分别是:


  • 最小二乘法

代数求解

矩阵求解

  • 梯度下降法。

本次,我们将使用梯度下降方法来解决线性回归问题。


Keras 方式实现


配合 TensorFlow 提供的高阶 API,我们省去了定义线性函数,定义损失函数,以及定义优化算法等 3 个步骤。


不过,高阶 API 实现过程实际上还不够精简,我们可以完全使用 TensorFlow Keras API 来实现线性回归。


Keras 本来是一个用 Python 编写的独立高阶神经网络 API,它能够以 TensorFlow, CNTK,或者 Theano 作为后端运行。


目前,TensorFlow 已经吸纳 Keras,并组成了 tf.keras 模块。官方介绍,tf.keras 和单独安装的 Keras 略有不同,但考虑到未来的发展趋势,主要以学习 tf.keras 为主。


12.png


初始化


11.png


我们这里使用 Keras 提供的 Sequential 顺序模型结构。向其中添加一个线性层。不同的地方在于,Keras 顺序模型第一层为线性层时,规定需指定输入维度,这里为 input_dim=1。


10.png


接下来,直接使用 .compile 编译模型,指定损失函数为 MSE 平方损失函数,优化器选择 SGD 随机梯度下降。然后,就可以使用 .fit 传入数据开始迭代了。


batch_size 是采用小批次训练的参数,主要用于解决一次性传入数据过多无法训练的问题。当然,由于示例数据本身较少,这里意义不大,但还是按照常规使用方法进行设置。


你会发现,完全使用 Keras 高阶 API 实际上只需要 4 行核心代码即可完成,相比于低阶 API 简化了很多。


完整代码:

import tensorflow as tf
TRUE_W = 3.0
TRUE_b = 2.0
NUM_SAMPLES = 100
X = tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
noise = tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
y = X * TRUE_W + TRUE_b + noise
# 模型训练
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=1,input_dim=1))
model.compile(optimizer='sgd',loss='mse')
model.fit(X,y,epochs=10,batch_size=32)
目录
相关文章
|
3月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
113 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
3月前
|
机器学习/深度学习 TensorFlow API
使用 TensorFlow 和 Keras 构建图像分类器
【10月更文挑战第2天】使用 TensorFlow 和 Keras 构建图像分类器
|
3月前
|
机器学习/深度学习 移动开发 TensorFlow
深度学习之格式转换笔记(四):Keras(.h5)模型转化为TensorFlow(.pb)模型
本文介绍了如何使用Python脚本将Keras模型转换为TensorFlow的.pb格式模型,包括加载模型、重命名输出节点和量化等步骤,以便在TensorFlow中进行部署和推理。
142 0
|
5月前
|
机器学习/深度学习 IDE API
【Tensorflow+keras】Keras 用Class类封装的模型如何调试call子函数的模型内部变量
该文章介绍了一种调试Keras中自定义Layer类的call方法的方法,通过直接调用call方法并传递输入参数来进行调试。
48 4
|
5月前
|
Apache 开发者 Java
Apache Wicket揭秘:如何巧妙利用模型与表单机制,实现Web应用高效开发?
【8月更文挑战第31天】本文深入探讨了Apache Wicket的模型与表单处理机制。Wicket作为一个组件化的Java Web框架,提供了多种模型实现,如CompoundPropertyModel等,充当组件与数据间的桥梁。文章通过示例介绍了模型创建及使用方法,并详细讲解了表单组件、提交处理及验证机制,帮助开发者更好地理解如何利用Wicket构建高效、易维护的Web应用程序。
64 0
|
5月前
|
机器学习/深度学习 API TensorFlow
深入解析TensorFlow 2.x中的Keras API:快速搭建深度学习模型的实战指南
【8月更文挑战第31天】本文通过搭建手写数字识别模型的实例,详细介绍了如何利用TensorFlow 2.x中的Keras API简化深度学习模型构建流程。从环境搭建到数据准备,再到模型训练与评估,展示了Keras API的强大功能与易用性,适合初学者快速上手。通过简单的代码,即可完成卷积神经网络的构建与训练,显著降低了深度学习的技术门槛。无论是新手还是专业人士,都能从中受益,高效实现模型开发。
55 0
|
5月前
|
TensorFlow 算法框架/工具
【Tensorflow+Keras】用Tensorflow.keras的方法替代keras.layers.merge
在TensorFlow 2.0和Keras中替代旧版keras.layers.merge函数的方法,使用了新的层如add, multiply, concatenate, average, 和 dot来实现常见的层合并操作。
43 1
|
5月前
|
TensorFlow 算法框架/工具
【Tensorflow+Keras】学习率指数、分段、逆时间、多项式衰减及自定义学习率衰减的完整实例
使用Tensorflow和Keras实现学习率衰减的完整实例,包括指数衰减、分段常数衰减、多项式衰减、逆时间衰减以及如何通过callbacks自定义学习率衰减策略。
87 0
|
5月前
|
API 算法框架/工具
【Tensorflow+keras】使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
42 0
|
5月前
|
TensorFlow API 算法框架/工具
【Tensorflow+keras】解决使用model.load_weights时报错 ‘str‘ object has no attribute ‘decode‘
python 3.6,Tensorflow 2.0,在使用Tensorflow 的keras API,加载权重模型时,报错’str’ object has no attribute ‘decode’
72 0