EM 算法求解高斯混合模型python实现

简介: 注:本文是对《统计学习方法》EM算法的一个简单总结。1. 什么是EM算法?  引用书上的话:概率模型有时既含有观测变量,又含有隐变量或者潜在变量。如果概率模型的变量都是观测变量,可以直接使用极大似然估计法或者贝叶斯的方法进行估计模型参数,但是当模型含有隐藏变量时,就不能简单使用这些方法了。

注:本文是对《统计学习方法》EM算法的一个简单总结。

1. 什么是EM算法?

  引用书上的话:

概率模型有时既含有观测变量,又含有隐变量或者潜在变量。如果概率模型的变量都是观测变量,可以直接使用极大似然估计法或者贝叶斯的方法进行估计模型参数,但是当模型含有隐藏变量时,就不能简单使用这些方法了。EM算法就是含有隐变量的概率模型参数的极大似然估计法,或者极大似然后验概率估计法。

2. EM 算法的一个小例子:三硬币模型

  假设有3枚硬币,记作A,B,C。这些硬币的正面出现的概率分别为\(\pi\)\(p\)\(q\)。进行如下的试验:先掷硬币A,根据A的结果选择B和C,如果掷A得到正面,则选择B;如果掷A得到反面,则选择C。接着掷出选出的硬币。记录下这次掷硬币的结果,如果是正面,则记作1,反面则记作0。独立重复做了n次试验(这里取n=10),得到结果如下:1,1,0,1,0,0,1,0,1,1。假设只能观测到抛硬币的结果,不能观测到抛硬币的过程,那么我们该如何估计三硬币的参数\(\pi\)\(p\)\(q\)呢?(也就是估计三枚硬币正面向上的概率)
  EM算法分为E步和M步。EM 算法首先选取了参数的初始值,记作\(\theta^{(0)}\)=(\(\pi^{(0)}\),\(p^{(0)}\),\(q^{(0)}\))。然后通过下面的步骤迭代计算参数的估计值,直到收敛为止,第\(i\)次迭代的参数的估计值记作\(\theta^{(i)}\)=(\(\pi^{(i)}\),\(p^{(i)}\),\(q^{(i)}\)),则EM算法的第\(i+1\)次迭代为:
  E步:计算模型在参数\(\pi^{(i)}\)\(p^{(i)}\)\(q^{(i)}\)下观测数据\(y_j\)来自掷硬币B的概率为
\(\mu_j^{(i+1)}\) = \(\frac{\pi^{(i)}(p^{(i)})^{y_j}(1-p^{(i)})^{1-y_j}}{\pi^{(i)}(p^{(i)})^{y_j}(1-p^{(i)})^{1-y_j} + (1-\pi^{(i)})(q^{(i)})^{y_j}(1-q^{(i)})^{1-y_j}}\)
  M步:计算模型新的参数的估计值:
\(\pi^{(i+1)}=\frac{1}{n}\sum_{j=1}^{n}\mu_j^{(i+1)}\)
\(p^{(i+1)}=\frac{\sum_{j=1}^{n}\mu_j^{(i+1)}y_j}{\sum_{j=1}^{n}\mu_j^{(i+1)}}\)
\(q^{(i+1)}=\frac{\sum_{j=1}^{n}(1-\mu_j^{(i+1)})y_j}{\sum_{j=1}^{n}(1-\mu_j^{(i+1)})}\)

下面带入具体的数字计算一下。如果模型的初始参数取值为:\(\pi{(0)}=0.5,p^{(0)}=0.5,q^{(0)}=0.5\)
那么带入上面的三个公式就可以计算出:\(\pi{(1)}=0.5,p^{(1)}=0.6,q^{(1)}=0.6\),继续迭代可以得到 \(\pi{(2)}=0.5,p^{(2)}=0.6,q^{(2)}=0.6\)。于是我们可以得到原来参数\(\theta\)的极大似然估计为:\(\hat{\pi}=0.5,\hat{p}=0.6,\hat{q}=0.6\)
如果取初值改为 \(\pi{(0)}=0.4,p^{(0)}=0.6,q^{(0)}=0.7\),那么我们可以计算出:\(\mu_1^{(1)}=\frac{0.4\times0.6}{0.4\times0.6+0.6\times0.7}=0.364\)
:\(\mu_0^{(1)}=\frac{0.4\times0.4}{0.4\times0.4+0.6\times0.3}=0.47\)
那么由于观测值中1出现了6次,0出现了4次,于是我们容易计算出:
\(\pi^{(1)}=\frac{0.364\times6+0.47\times4}{10}=0.4064\)
\(p^{(1)}=\frac{0.364\times1\times6+0.47\times0\times4}{0.364\times6+0.47\times4}=0.537\)
\(q^{(1)}=\frac{(1-0.364)\times6}{(1-0.364)\times6+(1-0.47)\times4}=0.643\)
显然,我们由于初始值选择的不同,导致了模型参数最终估计的不同,这说明了EM算法对于参数的初始值是非常敏感的。

3 EM 算法的步骤

输入:观测变量数据\(Y\),隐变量数据\(Z\),联合分布\(P(Y,Z|\theta)\),条件分布\(P(Z|Y,\theta)\);
输出:模型参数\(\theta\)

(1) 选择模型参数的初值\(\theta^{(0)}\),开始迭代;
(2)\(E\)步:记\(\theta^{(i)}\)为第\(i\)次迭代参数\(\theta\)的估计值,在第\(i+1\)次迭代的E步,计算
\(Q(\theta,\theta^{(i)}=E_Z[logP(Y,Z|\theta)|Y,\theta^{(i)}]=\sum_{Z}logP(Y,Z|\theta)P(Z|Y,\theta^{(i)})\)
这里\(P(Z|Y,\theta^{(i)})\)是在给定观测数据Y和当前参数估计的\(\theta^{(i)}\)下隐变量数据\(Z\)的条件概率分布;
(3)\(M\)步:求使得\(Q(\theta,\theta^{(i)})\)极大化的\(\theta\),确定第\(i+1\)次迭代的参数的估计值\(\theta^{(i+1)}\)
\(\theta^{(i+1)}=arg max_{\theta}Q(\theta,\theta^{(i)})\)
(4)重复第(2)步和第(3)步,直到算法收敛。
上面定义的函数\(Q(\theta,\theta^{(i)})\)是EM算法的核心,称为Q函数(Q function)。下面给出Q函数的定义:
(Q函数) 完全数据的对数似然函数\(logP(Y,Z|\theta)\)关于给定观测数据\(Y\)和当前参数\(\theta^{(i)}\)下对未观测数据\(Z\)的条件概率分布\(P(Z|Y,\theta^{(i)})\)的期望称为\(Q\)函数,即:
\(Q(\theta,\theta^{(i)}=E_Z[logP(Y,Z|\theta)|Y,\theta^{(i)}]\)

4. EM算法的收敛性

1. 设\(P(Y|\theta)\)为观测数据的似然函数,\(\theta^{(i)}(i=1,2,...)\)为EM算法得到的参数估计序列,\(P(Y|\theta^{(i)}\)为对应的似然函数序列,则\(P(Y|\theta^{(i)}\)是单调递增的,即:

\(P(Y|\theta^{(i+1)} \ge P(Y|\theta^{(i)})\)

2. 设 \(L(\theta)=logP(Y|\theta)\)为观测数据的对数似然函数,\(\theta^{(i)}\)为EM算法得到的参数估计序列,\(L(\theta^{(i)})\)为对应的对数似然函数序列,则

(1) 如果\(P(Y|\theta)\)有上界,则\(L(\theta^{(i)})=logP(Y|\theta^{(i)})\)收敛到某一值\(L^{*}\);

(2) 在函数 \(Q(\theta,\theta^{'})\)\(L(\theta)\)满足一定的条件下,由EM算法得到的参数估计序列\(\theta^{(i)}\)的收敛值\(\theta^{*}\)\(L(\theta)\)的稳定点。

5. 高斯混合模型(GMM)

1. 高斯混合模型的定义:高斯混合模型是指具有如下形式的概率分布模型:

\(P(Y|\theta)=\sum_{k=1}^{K}\alpha_k\phi(y|\theta_k)\)
其中\(\alpha_k\)是系数,\(\alpha_k \ge 0\),\(\sum_{k=1}^{K}\alpha_k=1\);\(\phi(y|\theta_k)\)是高斯分布(正态分布)密度函数,\(\theta_k=(\mu_k,\sigma_k^{2})\),
\(\phi(y|\theta_k)=\frac{1}{\sqrt(2\pi)\sigma_k}exp(-\frac{(y-\mu_k)^2}{2\sigma_k^2})\) 称为第\(k\)个分模型。

2. 使用EM算法估计高斯混合模型

  具体推导过程从略,可以参见《统计学习方法》。这里直接给出结果:
高斯混合模型的EM估计算法
输入:观测数据\(y_1,y_2,...,y_N\),高斯混合模型;
输出:高斯混合模型的参数。
(1)取参数的初始值迭代
(2)E步,依据当前模型的参数,计算分模型k对观测数据\(y_j\)的响应度,
\(\hat\gamma_{jk}=\frac{\alpha_k\phi(y_j|\theta_k)}{\sum_{k=1}^{K}\alpha_k\phi(y_j|\theta_k)},j=1,2,...,N,k=1,2,...K\)
(3)计算新一轮迭代的模型参数
\(\hat\mu_k=\frac{\sum_{j=1}^{N}\hat\gamma_{jk}y_j}{\sum_{j=1}^{N}\hat\gamma_{jk}},k=1,2,...,K\)
\(\hat\sigma_k^2=\frac{\sum_{j=1}^{N}\hat\gamma_{jk}(y_j-\mu_k)^2}{\sum_{j=1}^{N}\hat\gamma_{jk}},k=1,2,...,K\)
\(\hat\alpha_k=\frac{\sum_{j=1}^{N}\hat\gamma_{jk}}{N},k=1,2,...,K\)
(4)重复(2)和(3)直到收敛。

6. python 实现 EM算法求解高斯混合模型

求解的步骤参考5,这里直接给出代码如下:

from __future__ import print_function
import numpy as np


def generateData(k,mu,sigma,dataNum):
    '''
    产生混合高斯模型的数据
    :param k: 比例系数
    :param mu: 均值
    :param sigma: 标准差
    :param dataNum:数据个数
    :return: 生成的数据
    '''
    # 初始化数据
    dataArray = np.zeros(dataNum,dtype=np.float32)
    # 逐个依据概率产生数据
    # 高斯分布个数
    n = len(k)
    for i in range(dataNum):
        # 产生[0,1]之间的随机数
        rand = np.random.random()
        Sum = 0
        index = 0
        while(index < n):
            Sum += k[index]
            if(rand < Sum):
                dataArray[i] = np.random.normal(mu[index],sigma[index])
                break
            else:
                index += 1
    return dataArray

def normPdf(x,mu,sigma):
    '''
    计算均值为mu,标准差为sigma的正态分布函数的密度函数值
    :param x: x值
    :param mu: 均值
    :param sigma: 标准差
    :return: x处的密度函数值
    '''
    return (1./np.sqrt(2*np.pi))*(np.exp(-(x-mu)**2/(2*sigma**2)))



def em(dataArray,k,mu,sigma,step = 10):
    '''
    em算法估计高斯混合模型
    :param dataNum: 已知数据个数
    :param k: 每个高斯分布的估计系数
    :param mu: 每个高斯分布的估计均值
    :param sigma: 每个高斯分布的估计标准差
    :param step:迭代次数
    :return: em 估计迭代结束估计的参数值[k,mu,sigma]
    '''
    # 高斯分布个数
    n = len(k)
    # 数据个数
    dataNum = dataArray.size
    # 初始化gama数组
    gamaArray = np.zeros((n,dataNum))
    for s in range(step):
        for i in range(n):
            for j in range(dataNum):
                Sum = sum([k[t]*normPdf(dataArray[j],mu[t],sigma[t]) for t in range(n)])
                gamaArray[i][j] = k[i]*normPdf(dataArray[j],mu[i],sigma[i])/float(Sum)
        # 更新 mu
        for i in range(n):
            mu[i] = np.sum(gamaArray[i]*dataArray)/np.sum(gamaArray[i])
        # 更新 sigma
        for i in range(n):
            sigma[i] = np.sqrt(np.sum(gamaArray[i]*(dataArray - mu[i])**2)/np.sum(gamaArray[i]))
        # 更新系数k
        for i in range(n):
            k[i] = np.sum(gamaArray[i])/dataNum

    return [k,mu,sigma]





if __name__ == '__main__':
    # 参数的准确值
    k = [0.3,0.4,0.3]
    mu = [2,4,3]
    sigma = [1,1,4]
    # 样本数
    dataNum = 5000
    # 产生数据
    dataArray = generateData(k,mu,sigma,dataNum)
    # 参数的初始值
    # 注意em算法对于参数的初始值是十分敏感的
    k0 = [0.3,0.3,0.4]
    mu0 = [1,2,2]
    sigma0 = [1,1,1]
    step = 6
    # 使用em算法估计参数
    k1,mu1,sigma1 = em(dataArray,k0,mu0,sigma0,step)
    # 输出参数的值
    print("参数实际值:")
    print("k:",k)
    print("mu:",mu)
    print("sigma:",sigma)
    print("参数估计值:")
    print("k1:",k1)
    print("mu1:",mu1)
    print("sigma1:",sigma1)

运行结果如下:
img_1d11d2d2d5f25496266f6da2f4e4e3c8.png
如果改变初始参数为:

k0=[0.1,0.4,0.5]
mu0=[3,3,2]
sigma0=[2,1,1.5]

那么得到的结果为:
img_7343a2c0c90c61b57a7bdccfa677b0f3.png
  对比可以发现,两次参数估计的结果相差还是很大的。这也说明了EM算法对初始参数的敏感性。在实际应用的过程中,我们可以通过多次试验验证来选择估计结果最好的那个作为初始参数。

热爱编程,热爱机器学习! github:http://www.github.com/Lyrichu github blog:http://Lyrichu.github.io 个人博客站点:http://www.movieb2b.com(不再维护)
目录
相关文章
|
24天前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
240 55
|
12天前
|
存储 缓存 监控
局域网屏幕监控系统中的Python数据结构与算法实现
局域网屏幕监控系统用于实时捕获和监控局域网内多台设备的屏幕内容。本文介绍了一种基于Python双端队列(Deque)实现的滑动窗口数据缓存机制,以处理连续的屏幕帧数据流。通过固定长度的窗口,高效增删数据,确保低延迟显示和存储。该算法适用于数据压缩、异常检测等场景,保证系统在高负载下稳定运行。 本文转载自:https://www.vipshare.com
103 66
|
2天前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的眼疾识别系统实现~人工智能+卷积网络算法
眼疾识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了4种常见的眼疾图像数据集(白内障、糖尿病性视网膜病变、青光眼和正常眼睛) 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,实现用户上传一张眼疾图片识别其名称。
16 4
基于Python深度学习的眼疾识别系统实现~人工智能+卷积网络算法
|
2月前
|
搜索推荐 Python
利用Python内置函数实现的冒泡排序算法
在上述代码中,`bubble_sort` 函数接受一个列表 `arr` 作为输入。通过两层循环,外层循环控制排序的轮数,内层循环用于比较相邻的元素并进行交换。如果前一个元素大于后一个元素,就将它们交换位置。
146 67
|
2月前
|
存储 搜索推荐 Python
用 Python 实现快速排序算法。
快速排序的平均时间复杂度为$O(nlogn)$,空间复杂度为$O(logn)$。它在大多数情况下表现良好,但在某些特殊情况下可能会退化为最坏情况,时间复杂度为$O(n^2)$。你可以根据实际需求对代码进行调整和修改,或者尝试使用其他优化策略来提高快速排序的性能
131 61
|
1月前
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
168 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
2天前
|
存储 算法 Serverless
剖析文件共享工具背后的Python哈希表算法奥秘
在数字化时代,文件共享工具不可或缺。哈希表算法通过将文件名或哈希值映射到存储位置,实现快速检索与高效管理。Python中的哈希表可用于创建简易文件索引,支持快速插入和查找文件路径。哈希表不仅提升了文件定位速度,还优化了存储管理和多节点数据一致性,确保文件共享工具高效运行,满足多用户并发需求,推动文件共享领域向更高效、便捷的方向发展。
|
16天前
|
存储 运维 监控
探索局域网电脑监控软件:Python算法与数据结构的巧妙结合
在数字化时代,局域网电脑监控软件成为企业管理和IT运维的重要工具,确保数据安全和网络稳定。本文探讨其背后的关键技术——Python中的算法与数据结构,如字典用于高效存储设备信息,以及数据收集、异常检测和聚合算法提升监控效率。通过Python代码示例,展示了如何实现基本监控功能,帮助读者理解其工作原理并激发技术兴趣。
50 20
|
9天前
|
算法 网络协议 Python
探秘Win11共享文件夹之Python网络通信算法实现
本文探讨了Win11共享文件夹背后的网络通信算法,重点介绍基于TCP的文件传输机制,并提供Python代码示例。Win11共享文件夹利用SMB协议实现局域网内的文件共享,通过TCP协议确保文件传输的完整性和可靠性。服务器端监听客户端连接请求,接收文件请求并分块发送文件内容;客户端则连接服务器、接收数据并保存为本地文件。文中通过Python代码详细展示了这一过程,帮助读者理解并优化文件共享系统。
|
14天前
|
存储 算法 Python
文件管理系统中基于 Python 语言的二叉树查找算法探秘
在数字化时代,文件管理系统至关重要。本文探讨了二叉树查找算法在文件管理中的应用,并通过Python代码展示了其实现过程。二叉树是一种非线性数据结构,每个节点最多有两个子节点。通过文件名的字典序构建和查找二叉树,能高效地管理和检索文件。相较于顺序查找,二叉树查找每次比较可排除一半子树,极大提升了查找效率,尤其适用于海量文件管理。Python代码示例包括定义节点类、插入和查找函数,展示了如何快速定位目标文件。二叉树查找算法为文件管理系统的优化提供了有效途径。
46 5