图文详解神秘的梯度下降算法原理(附Python代码)

简介: 图文详解神秘的梯度下降算法原理(附Python代码)

目录

1 引例

给定如图所示的某个函数,如何通过计算机算法编程求 f ( x ) m i n f(x)_{min} f(x)

min



image.png

2 数值解法

传统方法是数值解法,如图所示


image.png

按照以下步骤迭代循环直至最优:


① 任意给定一个初值 x 0 x_0 x

0



② 随机生成增量方向,结合步长生成 Δ x \varDelta x Δx;


③ 计算比较 f ( x 0 ) f\left( x_0 \right) f(x

0


)与 f ( x 0 + Δ x ) f\left( x_0+\varDelta x \right) f(x

0


+Δx)的大小,若 f ( x 0 + Δ x ) < f ( x 0 ) f\left( x_0+\varDelta x \right) <f\left( x_0 \right) f(x

0


+Δx)<f(x

0


)则更新位置,否则重新生成 Δ x \varDelta x Δx;


④ 重复②③直至收敛到最优 f ( x ) m i n f(x)_{min} f(x)

min



数值解法最大的优点是编程简明,但缺陷也很明显:


① 初值的设定对结果收敛快慢影响很大;


② 增量方向随机生成,效率较低;


③ 容易陷入局部最优解;


④ 无法处理“高原”类型函数。


所谓陷入局部最优解是指当迭代进入到某个极小值或其邻域时,由于步长选择不恰当,无论正方向还是负方向,学习效果都不如当前,导致无法向全局最优迭代。就本问题而言如图所示,当迭代陷入 x = x j x=x_j x=x

j


时,由于学习步长 s t e p step step的限制,无法使 f ( x j ± S t e p ) < f ( x j ) f\left( x_j\pm Step \right) <f(x_j) f(x

j


±Step)<f(x

j


),因此迭代就被锁死在了图中的红色区段。可以看出 x = x j x=x_j x=x

j


并非期望的全局最优。


image.png

若出现下图所示的“高原”函数,也可能使迭代得不到更新。


image.png

3 梯度下降算法

梯度下降算法可视为数值解法的一种改进,阐述如下:


记第 k k k轮迭代后,自变量更新为 x = x k x=x_k x=x

k


,令目标函数 f ( x ) f(x) f(x)在 x = x k x=x_k x=x

k


泰勒展开:


f ( x ) = f ( x k ) + f ′ ( x k ) ( x − x k ) + o ( x ) f\left( x \right) =f\left( x_k \right) +f'\left( x_k \right) \left( x-x_k \right) +o(x)

f(x)=f(x

k


)+f

(x

k


)(x−x

k


)+o(x)


考察 f ( x ) m i n f(x)_{min} f(x)

min


,则期望 f ( x k + 1 ) < f ( x k ) f\left( x_{k+1} \right) <f\left( x_k \right) f(x

k+1


)<f(x

k


),从而:


f ( x k + 1 ) − f ( x k ) = f ′ ( x k ) ( x k + 1 − x k ) < 0 f\left( x_{k+1} \right) -f\left( x_k \right) =f'\left( x_k \right) \left( x_{k+1}-x_k \right) <0

f(x

k+1


)−f(x

k


)=f

(x

k


)(x

k+1


−x

k


)<0


若 f ′ ( x k ) > 0 f'\left( x_k \right) >0 f

(x

k


)>0则 x k + 1 < x k x_{k+1}<x_k x

k+1


<x

k


,即迭代方向为负;反之为正。不妨设 x k + 1 − x k = − f ′ ( x k ) x_{k+1}-x_k=-f'(x_k) x

k+1


−x

k


=−f

(x

k


),从而保证 f ( x k + 1 ) − f ( x k ) < 0 f\left( x_{k+1} \right) -f\left( x_k \right) <0 f(x

k+1


)−f(x

k


)<0。必须指出,泰勒公式成立的条件是 x → x 0 x\rightarrow x_0 x→x

0


,故 ∣ f ′ ( x k ) ∣ |f'\left( x_k \right) | ∣f

(x

k


)∣不能太大,否则 x k + 1 x_{k+1} x

k+1


与 x k x_{k} x

k


距离太远产生余项误差。因此引入学习率 γ ∈ ( 0 , 1 ) \gamma \in \left( 0, 1 \right) γ∈(0,1)来减小偏移度,即 x k + 1 − x k = − γ f ′ ( x k ) x_{k+1}-x_k=-\gamma f'(x_k) x

k+1


−x

k


=−γf

(x

k


)


在工程上,学习率 γ \gamma γ要结合实际应用合理选择, γ \gamma γ过大会使迭代在极小值两侧振荡,算法无法收敛; γ \gamma γ过小会使学习效率下降,算法收敛慢。


对于向量 ,将上述迭代公式推广为


x k + 1 = x k − γ ∇ x k {\boldsymbol{x}_{\boldsymbol{k}+1}=\boldsymbol{x}_{\boldsymbol{k}}-\gamma \nabla _{\boldsymbol{x}_{\boldsymbol{k}}}}

x

k+1


=x

k


−γ∇

x

k




其中 ∇ x = ( ∂ f ( x ) ∂ x 1 , ∂ f ( x ) ∂ x 2 , ⋯ ⋯   , ∂ f ( x ) ∂ x n ) T \nabla _{\boldsymbol{x}}=\left( \frac{\partial f(\boldsymbol{x})}{\partial x_1},\frac{\partial f(\boldsymbol{x})}{\partial x_2},\cdots \cdots ,\frac{\partial f(\boldsymbol{x})}{\partial x_n} \right) ^T ∇

x


=(

∂x

1


∂f(x)


,

∂x

2


∂f(x)


,⋯⋯,

∂x

n


∂f(x)


)

T

为多元函数的梯度,故此迭代算法也称为梯度下降算法


image.png

梯度下降算法通过函数梯度确定了每一次迭代的方向和步长,提高了算法效率。但从原理上可以知道,此算法并不能解决数值解法中初值设定、局部最优陷落和部分函数锁死的问题。

4 代码实战:Logistic回归

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
from Logit import Logit
'''
* @breif: 从CSV中加载指定数据
* @param[in]: file -> 文件名
* @param[in]: colName -> 要加载的列名
* @param[in]: mode -> 加载模式, set: 列名与该列数据组成的字典, df: df类型
* @retval: mode模式下的返回值
'''
def loadCsvData(file, colName, mode='df'):
    assert mode in ('set', 'df')
    df = pd.read_csv(file, encoding='utf-8-sig', usecols=colName)
    if mode == 'df':
        return df
    if mode == 'set':
        res = {}
        for col in colName:
            res[col] = df[col].values
        return res
if __name__ == '__main__':
    # ============================
    # 读取CSV数据
    # ============================
    csvPath = os.path.abspath(os.path.join(__file__, "../../data/dataset3.0alpha.csv"))
    dataX = loadCsvData(csvPath, ["含糖率", "密度"], 'df')
    dataY = loadCsvData(csvPath, ["好瓜"], 'df')
    label = np.array([
        1 if i == "是" else 0
        for i in list(map(lambda s: s.strip(), list(dataY['好瓜'])))
    ])
    # ============================
    # 绘制样本点
    # ============================
    line_x = np.array([np.min(dataX['密度']), np.max(dataX['密度'])])
    mpl.rcParams['font.sans-serif'] = [u'SimHei']
    plt.title('对数几率回归模拟\nLogistic Regression Simulation')
    plt.xlabel('density')
    plt.ylabel('sugarRate')
    plt.scatter(dataX['密度'][label==0],
                dataX['含糖率'][label==0],
                marker='^',
                color='k',
                s=100,
                label='坏瓜')
    plt.scatter(dataX['密度'][label==1],
                dataX['含糖率'][label==1],
                marker='^',
                color='r',
                s=100,
                label='好瓜')
    # ============================
    # 实例化对数几率回归模型
    # ============================
    logit = Logit(dataX, label)
    # 采用梯度下降法
    logit.logitRegression(logit.gradientDescent)
    line_y = -logit.w[0, 0] / logit.w[1, 0] * line_x - logit.w[2, 0] / logit.w[1, 0]
    plt.plot(line_x, line_y, 'b-', label="梯度下降法")
    # 绘图
    plt.legend(loc='upper left')
    plt.show()

image.png

🔥 更多精彩专栏

目录
相关文章
|
16天前
|
监控 算法 安全
深度洞察内网监控电脑:基于Python的流量分析算法
在当今数字化环境中,内网监控电脑作为“守城卫士”,通过流量分析算法确保内网安全、稳定运行。基于Python的流量分析算法,利用`scapy`等工具捕获和解析数据包,提取关键信息,区分正常与异常流量。结合机器学习和可视化技术,进一步提升内网监控的精准性和效率,助力企业防范潜在威胁,保障业务顺畅。本文深入探讨了Python在内网监控中的应用,展示了其实战代码及未来发展方向。
|
1月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的眼疾识别系统实现~人工智能+卷积网络算法
眼疾识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了4种常见的眼疾图像数据集(白内障、糖尿病性视网膜病变、青光眼和正常眼睛) 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,实现用户上传一张眼疾图片识别其名称。
105 5
基于Python深度学习的眼疾识别系统实现~人工智能+卷积网络算法
|
12天前
|
存储 缓存 Java
Python高性能编程:五种核心优化技术的原理与Python代码
Python在高性能应用场景中常因执行速度不及C、C++等编译型语言而受质疑,但通过合理利用标准库的优化特性,如`__slots__`机制、列表推导式、`@lru_cache`装饰器和生成器等,可以显著提升代码效率。本文详细介绍了这些实用的性能优化技术,帮助开发者在不牺牲代码质量的前提下提高程序性能。实验数据表明,这些优化方法能在内存使用和计算效率方面带来显著改进,适用于大规模数据处理、递归计算等场景。
49 5
Python高性能编程:五种核心优化技术的原理与Python代码
|
14天前
|
存储 算法 安全
控制局域网上网软件之 Python 字典树算法解析
控制局域网上网软件在现代网络管理中至关重要,用于控制设备的上网行为和访问权限。本文聚焦于字典树(Trie Tree)算法的应用,详细阐述其原理、优势及实现。通过字典树,软件能高效进行关键词匹配和过滤,提升系统性能。文中还提供了Python代码示例,展示了字典树在网址过滤和关键词屏蔽中的具体应用,为局域网的安全和管理提供有力支持。
47 17
|
23天前
|
存储 监控 算法
员工电脑监控屏幕场景下 Python 哈希表算法的探索
在数字化办公时代,员工电脑监控屏幕是保障信息安全和提升效率的重要手段。本文探讨哈希表算法在该场景中的应用,通过Python代码例程展示如何使用哈希表存储和查询员工操作记录,并结合数据库实现数据持久化,助力企业打造高效、安全的办公环境。哈希表在快速检索员工信息、优化系统性能方面发挥关键作用,为企业管理提供有力支持。
44 20
|
18天前
|
存储 人工智能 算法
深度解密:员工飞单需要什么证据之Python算法洞察
员工飞单是企业运营中的隐性风险,严重侵蚀公司利润。为应对这一问题,精准搜集证据至关重要。本文探讨如何利用Python编程语言及其数据结构和算法,高效取证。通过创建Transaction类存储交易数据,使用列表管理订单信息,结合排序算法和正则表达式分析交易时间和聊天记录,帮助企业识别潜在的飞单行为。Python的强大功能使得从交易流水和沟通记录中提取关键证据变得更加系统化和高效,为企业维权提供有力支持。
|
2月前
|
Python
课程设计项目之基于Python实现围棋游戏代码
游戏进去默认为九路玩法,当然也可以选择十三路或是十九路玩法 使用pycharam打开项目,pip安装模块并引用,然后运行即可, 代码每行都有详细的注释,可以做课程设计或者毕业设计项目参考
72 33
|
17天前
|
存储 算法 安全
U 盘管控情境下 Python 二叉搜索树算法的深度剖析与探究
在信息技术高度发达的今天,数据安全至关重要。U盘作为常用的数据存储与传输工具,其管控尤为关键。本文探讨Python中的二叉搜索树算法在U盘管控中的应用,通过高效管理授权U盘信息,防止数据泄露,保障信息安全。二叉搜索树具有快速插入和查找的优势,适用于大量授权U盘的管理。尽管存在一些局限性,如树结构退化问题,但通过优化和改进,如采用自平衡树,可以有效提升U盘管控系统的性能和安全性。
21 3
|
1月前
|
存储 算法 Serverless
剖析文件共享工具背后的Python哈希表算法奥秘
在数字化时代,文件共享工具不可或缺。哈希表算法通过将文件名或哈希值映射到存储位置,实现快速检索与高效管理。Python中的哈希表可用于创建简易文件索引,支持快速插入和查找文件路径。哈希表不仅提升了文件定位速度,还优化了存储管理和多节点数据一致性,确保文件共享工具高效运行,满足多用户并发需求,推动文件共享领域向更高效、便捷的方向发展。
|
2月前
|
JavaScript API C#
【Azure Developer】Python代码调用Graph API将外部用户添加到组,结果无效,也无错误信息
根据Graph API文档,在单个请求中将多个成员添加到组时,Python代码示例中的`members@odata.bind`被错误写为`members@odata_bind`,导致用户未成功添加。
49 10

热门文章

最新文章