期望最大化(Expectation Maximization,EM)
忘记啥时候看的了,一直在草稿箱里面,今天有空整理一下,方便以后查阅。
期望最大化(Expectation Maximization,EM)
- EM算法的基本思想
- 数学公式
- EM算法的具体步骤
- EM算法的应用
- 高斯混合模型(GMM)中的EM算法
- Python代码实现:高斯混合模型中的EM算法
- 代码解析
- 结果
- 总结
- 参考资料
期望最大化(EM)算法是一种用于含有隐变量(latent variables)或缺失数据(missing data)模型的参数估计方法。EM算法通过反复进行“期望”(E步)和“最大化”(M步)来寻找最大似然估计(MLE)。它通常用于需要估计潜在或隐含变量的概率模型,尤其在混合模型、聚类分析等问题中广泛应用。
EM算法的基本思想
EM算法通过迭代进行两步操作:
-
E步(Expectation,期望步):在当前参数估计下,计算隐变量的期望值,或者说是对隐变量进行“填补”。它计算隐变量的条件分布,基于观察到的数据和当前的模型参数。
-
M步(Maximization,最大化步):根据E步计算出的隐变量期望值,最大化完全数据的似然函数,更新模型参数。
EM算法的目标是最大化观察数据的似然函数。由于模型中存在隐变量,直接计算似然函数比较困难,因此EM算法通过交替进行E步和M步来近似计算最大似然估计。
数学公式
假设我们有数据 XXX,隐变量为 ZZZ,模型参数为 θ\thetaθ,我们的目标是最大化对数似然函数:
L(θ)=logP(X∣θ)\mathcal{L}(\theta) = \log P(X \mid \theta) L(θ)=logP(X∣θ)
由于隐变量 ZZZ 是不可观察的,直接计算似然函数比较困难。EM算法通过引入隐变量的完整数据集来进行求解。完整数据的似然函数为:
L(θ,Z)=logP(X,Z∣θ)\mathcal{L}(\theta, Z) = \log P(X, Z \mid \theta) L(θ,Z)=logP(X,Z∣θ)
EM算法通过以下两个步骤进行迭代:
-
E步:计算在当前参数估计下,隐变量的条件概率分布 Q(θ,θ(t))=EZ[L(θ,Z)]Q(\theta, \theta^{(t)}) = \mathbb{E}_Z[\mathcal{L}(\theta, Z)]Q(θ,θ(t))=EZ[L(θ,Z)],即计算隐变量的期望值。
-
M步:最大化期望似然函数 Q(θ,θ(t))Q(\theta, \theta^{(t)})Q(θ,θ(t)),找到更新后的参数 θ(t+1)\theta^{(t+1)}θ(t+1)。
EM算法的具体步骤
- 初始化:选择初始的参数估计 θ(0)\theta^{(0)}θ(0)。
- E步:计算隐变量的条件概率(期望),即 P(Z∣X,θ(t))P(Z \mid X, \theta^{(t)})P(Z∣X,θ(t))。
- M步:通过最大化期望似然函数Q(θ,θ(t))Q(\theta, \theta^{(t)})Q(θ,θ(t))来更新模型参数。
- 重复:重复E步和M步,直到参数收敛。
EM算法的应用
EM算法的一个典型应用是高斯混合模型(Gaussian Mixture Model,GMM)。GMM模型假设数据是由多个高斯分布混合而成,其中每个高斯分布对应于一个隐含的类别。通过EM算法,我们可以估计GMM模型的参数(每个高斯分布的均值、方差和混合系数)。
高斯混合模型(GMM)中的EM算法
-
E步:计算每个数据点属于每个高斯分布的概率,称为责任度(responsibility)。也就是给定当前模型参数,计算每个数据点属于每个高斯分布的后验概率。
-
M步:根据E步计算的责任度,重新估计每个高斯分布的均值、协方差和混合系数。
Python代码实现:高斯混合模型中的EM算法
下面的代码实现了一个简单的EM算法,来估计一个高斯混合模型(GMM)的参数。
import numpy as np
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt# 生成高斯混合模型数据
np.random.seed(42)# 设置GMM的参数
means = np.array([[3, 3], [-3, -3]]) # 两个高斯分布的均值
covariances = np.array([[[1, 0], [0, 1]], [[1, 0], [0, 1]]]) # 两个高斯分布的协方差
weights = np.array([0.5, 0.5]) # 混合系数# 生成1000个数据点
n_samples = 1000
n_components = len(means)
X = np.vstack([np.random.multivariate_normal(means[i], covariances[i], int(weights[i] * n_samples))for i in range(n_components)
])# 可视化生成的数据
plt.scatter(X[:, 0], X[:, 1], s=10, alpha=0.5)
plt.title("Generated Data from GMM")
plt.show()# EM算法实现
def em_gmm(X, n_components, max_iter=100, tol=1e-6):n_samples, n_features = X.shape# 初始化GMM参数np.random.seed(42)means = np.random.randn(n_components, n_features)covariances = np.array([np.eye(n_features)] * n_components)weights = np.ones(n_components) / n_componentslog_likelihoods = []for i in range(max_iter):# E步:计算责任度resp = np.zeros((n_samples, n_components))for j in range(n_components):resp[:, j] = weights[j] * multivariate_normal.pdf(X, means[j], covariances[j])# 归一化责任度resp = resp / resp.sum(axis=1)[:, np.newaxis]# M步:更新参数N_k = resp.sum(axis=0) # 每个组件的总责任度weights = N_k / n_samples # 更新混合系数means = np.dot(resp.T, X) / N_k[:, np.newaxis] # 更新均值covariances = np.array([np.dot((resp[:, k] * (X - means[k])).T, X - means[k]) / N_k[k]for k in range(n_components)]) # 更新协方差# 计算对数似然log_likelihood = np.sum(np.log(np.sum(resp, axis=1)))log_likelihoods.append(log_likelihood)# 收敛判断if i > 0 and abs(log_likelihood - log_likelihoods[-2]) < tol:breakreturn means, covariances, weights, log_likelihoods# 使用EM算法拟合数据
means, covariances, weights, log_likelihoods = em_gmm(X, n_components=2)# 打印最终的参数估计
print("Means:\n", means)
print("\nCovariances:\n", covariances)
print("\nWeights:\n", weights)# 绘制对数似然的收敛曲线
plt.plot(log_likelihoods)
plt.title("Log-Likelihood Convergence")
plt.xlabel("Iteration")
plt.ylabel("Log-Likelihood")
plt.show()
代码解析
-
数据生成:
- 生成了一个由两个二维高斯分布混合而成的数据集。每个高斯分布的均值和协方差矩阵是事先设定的。
-
EM算法实现:
- 初始化参数:均值、协方差和混合系数的初始值随机生成。
- E步:计算每个数据点属于每个高斯分布的责任度(后验概率),责任度是基于当前参数和数据计算出来的。
- M步:根据责任度更新参数。即计算每个高斯分布的均值、协方差和混合系数。
- 对数似然计算:每次迭代后,计算对数似然值,以检查算法是否收敛。
- 收敛判断:当两次迭代之间的对数似然差小于设定的阈值时,停止迭代。
-
结果可视化:
- 可视化生成的数据分布。
- 绘制对数似然的收敛曲线,观察EM算法是否收敛。
结果
- 在代码运行后,你将得到两个高斯分布的估计参数(均值、协方差和混合系数)。这些估计的参数应该与我们生成数据时所设定的参数相近。
- 收敛图展示了EM算法在每次迭代中的对数似然变化,通常会看到它趋于稳定。
总结
EM算法是一种非常强大的参数估计方法,特别适用于含有隐变量的模型。它的核心思想是通过迭代优化“期望”(E步)
参考资料
机器学习实验报告——EM算法
机器学习 | 深入理解EM算法