机器学习 [白板推导](八)[EM算法]
10. EM期望最大算法(Expectation Maximization)
10.1 算法简介
概率模型求解时,很多时候数据只有观测变量,因为缺少隐变量而无法求得解析解,也很难通过极大似然估计法对完整的模型参数进行估计。
EM算法即为含有隐变量的极大似然估计法(或含有隐变量的极大后验概率估计法)。
举例:
- 有并不均匀的 AAA,BBB 和 CCC 三个硬币,正面朝上的概率分别为 ooo,ppp 和 qqq ,首先抛掷 AAA,若正面朝上就抛掷 BBB,若反面朝上就抛掷 CCC,记录最终的抛掷结果为 xi=1 or 0x_i=1\ \text{or}\ 0xi=1 or 0,但硬币 AAA 的抛掷结果作为隐变量未能观测,使得模型参数 ooo,ppp,qqq 不能精准求解。
- 若使用极大似然估计法,则为
θ=argmaxθ logP(X∣θ)\theta=\underset{\theta}{\arg \max} \ \log P(X|\theta)θ=θargmax logP(X∣θ),忽略了隐变量,但由于 P(X∣θ)=∫ZP(X,Z∣θ)dZ=∫ZP(X∣Z,θ)⋅P(Z∣θ)dZ=EZ∣θ[P(X∣Z,θ)],(10.1)\begin{aligned} P(X|\theta)&=\int_Z P(X,Z|\theta) dZ\\&=\int_Z P(X|Z,\theta)\cdot P(Z|\theta) dZ\\&=E_{Z|\theta}[P(X|Z,\theta)],\tag{10.1} \end{aligned}P(X∣θ)=∫ZP(X,Z∣θ)dZ=∫ZP(X∣Z,θ)⋅P(Z∣θ)dZ=EZ∣θ[P(X∣Z,θ)],(10.1)
因此可以使用生成式概率模型的思想,认为 XXX 是由 ZZZ 生成的,来对模型参数进行求解
EM算法分为两步,E步是对 P(X∣θ)P(X|\theta)P(X∣θ) 求期望,M步是迭代模型参数 θ\thetaθ 使得期望最大。
10.2 算法收敛性证明
首先直接看公式
θ(t+1)=argmaxθ Ez∣x,θ(t)[logP(x,z∣θ)]=argmaxθ ∫z[logP(x,z∣θ)⋅P(z∣x,θ(t))]dz,(10.2)
\begin{aligned}
\theta^{(t+1)}&=\underset{\theta}{\arg \max} \ E_{z|x,\theta^{(t)}}\left [\log P(x, z|\theta)\right ]\\&=\underset{\theta}{\arg \max} \ \int_z\left [\log P(x, z|\theta)\cdot P(z|x,\theta^{(t)})\right ]dz
, \tag{10.2}\end{aligned}
θ(t+1)=θargmax Ez∣x,θ(t)[logP(x,z∣θ)]=θargmax ∫z[logP(x,z∣θ)⋅P(z∣x,θ(t))]dz,(10.2)
其中:
- xxx 是样本数据,zzz 是隐变量,θ\thetaθ 是模型参数;
- logP(x,z∣θ)\log P(x,z|\theta)logP(x,z∣θ)被称为完备数据概率,因为其包含了 xxx 和 zzz 的联合概率分布;
- p(z∣x,θ(t))p(z|x,\theta^{(t)})p(z∣x,θ(t)) 是后验概率。
这个公式是通过调整参数 θ\thetaθ,使得 x,zx,zx,z 的完备数据的联合概率密度期望最大,因此被称为期望最大算法,核心思路类似于最大似然估计。
具体来看,参数是逐步更新的,也就是每次给定当前的参数 θ(t)\theta^{(t)}θ(t),调整 θ\thetaθ 使得完备数据概率的期望最大,并将其赋值给 θ(t+1)\theta^{(t+1)}θ(t+1),即完成了参数更新,直至收敛。
接下来看收敛性证明:
- 根据极大似然估计思想,若保证 logP(x∣θ(t))⩽logP(x∣θ(t+1))\log P(x|\theta^{(t)})\leqslant \log P(x|\theta^{(t+1)})logP(x∣θ(t))⩽logP(x∣θ(t+1)),则可以保证EM算法的有效性和收敛性。
- 已知 logP(x∣θ)=logP(x,z∣θ)P(z∣x,θ)=logP(x,z∣θ)−logP(z∣x,θ)\log P(x|\theta)=\log \frac{P(x,z|\theta)}{P(z|x,\theta)}=\log P(x,z|\theta)-\log P(z|x,\theta)logP(x∣θ)=logP(z∣x,θ)P(x,z∣θ)=logP(x,z∣θ)−logP(z∣x,θ),将两边对 zzz 积分:
- 等号左边因为所有变量都与 zzz 无关,所以积分值为 logP(x∣θ)⋅1=logP(x∣θ)\log P(x|\theta)\cdot 1=\log P(x|\theta)logP(x∣θ)⋅1=logP(x∣θ)。
- 等号右边为 ∫zP(z∣x,θ(t))⋅logP(x,z∣θ)dz−∫zP(z∣x,θ(t))⋅logP(z∣x,θ)dz\int_z P(z|x,\theta^{(t)})\cdot \log P(x,z|\theta)dz-\int_z P(z|x,\theta^{(t)})\cdot \log P(z|x,\theta)dz∫zP(z∣x,θ(t))⋅logP(x,z∣θ)dz−∫zP(z∣x,θ(t))⋅logP(z∣x,θ)dz,分别记作 Q(θ,θ(t))Q(\theta,\theta^{(t)})Q(θ,θ(t)) 和 H(θ,θ(t))H(\theta,\theta^{(t)})H(θ,θ(t)),因此等号右边为 Q(θ,θ(t))−Q(θ,θ(t))Q(\theta,\theta^{(t)})-Q(\theta,\theta^{(t)})Q(θ,θ(t))−Q(θ,θ(t))。
- 若能保证 Q(θ(t+1),θ(t))⩾Q(θ,θ(t))Q(\theta^{(t+1)},\theta^{(t)})\geqslant Q(\theta,\theta^{(t)})Q(θ(t+1),θ(t))⩾Q(θ,θ(t)),H(θ(t+1),θ(t))⩽H(θ,θ(t))H(\theta^{(t+1)},\theta^{(t)})\leqslant H(\theta,\theta^{(t)})H(θ(t+1),θ(t))⩽H(θ,θ(t)),则可以保证 logP(x∣θ(t))⩽logP(x∣θ(t+1))\log P(x|\theta^{(t)})\leqslant \log P(x|\theta^{(t+1)})logP(x∣θ(t))⩽logP(x∣θ(t+1)),从而保证期望上升。
- Q(θ,θ(t))Q(\theta,\theta^{(t)})Q(θ,θ(t)) 即为定义(第一行的公式),因此 Q(θ(t+1),θ(t))Q(\theta^{(t+1)},\theta^{(t)})Q(θ(t+1),θ(t)) 即为 Q(θ,θ(t))Q(\theta,\theta^{(t)})Q(θ,θ(t)) 的最大值,因此 Q(θ(t+1),θ(t))⩾Q(θ,θ(t))Q(\theta^{(t+1)},\theta^{(t)})\geqslant Q(\theta,\theta^{(t)})Q(θ(t+1),θ(t))⩾Q(θ,θ(t)) 恒成立。
- 另外有
H(θ(t+1),θ(t))−H(θ(t),θ(t))=∫zP(z∣x,θ(t))⋅logP(z∣x,θ(t+1))dz−∫zP(z∣x,θ(t))⋅logP(z∣x,θ(t))dz=∫zP(z∣x,θ(t))⋅logP(z∣x,θ(t+1))dzP(z∣x,θ(t))dz=−KL(P(z∣x,θ(t)) ∥ P(z∣x,θ(t)))⩽0,(10.3) \begin{aligned} &H(\theta^{(t+1)},\theta^{(t)}) - H(\theta^{(t)},\theta^{(t)})\\ =& \int_z P(z|x,\theta^{(t)})\cdot \log P(z|x,\theta^{(t+1)})dz-\int_z P(z|x,\theta^{(t)})\cdot \log P(z|x,\theta^{(t)})dz\\ =&\int_z P(z|x,\theta^{(t)})\cdot \log \frac{P(z|x,\theta^{(t+1)})dz}{P(z|x,\theta^{(t)})dz}=-KL\left (P(z|x,\theta^{(t)})\ \|\ P(z|x,\theta^{(t)}) \right )\leqslant 0, \tag{10.3} \end{aligned} ==H(θ(t+1),θ(t))−H(θ(t),θ(t))∫zP(z∣x,θ(t))⋅logP(z∣x,θ(t+1))dz−∫zP(z∣x,θ(t))⋅logP(z∣x,θ(t))dz∫zP(z∣x,θ(t))⋅logP(z∣x,θ(t))dzP(z∣x,θ(t+1))dz=−KL(P(z∣x,θ(t)) ∥ P(z∣x,θ(t)))⩽0,(10.3)
(或者不从KL散度的角度,也可以用 E[logx]⩽logE[x]E[\log x ] \leqslant \log E[x]E[logx]⩽logE[x] 证明),因此 H(θ(t+1),θ(t))⩽H(θ(t),θ(t))H(\theta^{(t+1)},\theta^{(t)})\leqslant H(\theta^{(t)},\theta^{(t)})H(θ(t+1),θ(t))⩽H(θ(t),θ(t)) 得证。
10.3 ELBO+KL散度导出公式
已知 logP(x∣θ)=logP(x,z∣θ)P(z∣x,θ)=logP(x,z∣θ)−logP(z∣x,θ)\log P(x|\theta)=\log \frac{P(x,z|\theta)}{P(z|x,\theta)}=\log P(x,z|\theta)-\log P(z|x,\theta)logP(x∣θ)=logP(z∣x,θ)P(x,z∣θ)=logP(x,z∣θ)−logP(z∣x,θ),设隐变量 zzz 的先验概率分布为 q(z)q(z)q(z),则:
logP(x∣θ)=logP(x,z∣θ)−logP(z∣x,θ)=logP(x,z∣θ)−logq(z)−[logP(z∣x,θ)−logq(z)]=logP(x,z∣θ)q(z)−logP(z∣x,θ)q(z),(10.4)
\begin{aligned}
\log P(x|\theta)&=\log P(x,z|\theta)-\log P(z|x,\theta)\\&=\log P(x,z|\theta)-\log q(z)-\left [\log P(z|x,\theta) -\log q(z) \right ] \\&=\log \frac{P(x,z|\theta)}{q(z)}-\log \frac{P(z|x,\theta) }{q(z)} , \tag{10.4}
\end{aligned}
logP(x∣θ)=logP(x,z∣θ)−logP(z∣x,θ)=logP(x,z∣θ)−logq(z)−[logP(z∣x,θ)−logq(z)]=logq(z)P(x,z∣θ)−logq(z)P(z∣x,θ),(10.4)
将上式两边对 zzz 积分:
- 左边 =∫zlogP(x∣θ)⋅q(z)dz=logP(x∣θ)=\int _z \log P(x|\theta) \cdot q(z) dz=\log P(x|\theta)=∫zlogP(x∣θ)⋅q(z)dz=logP(x∣θ)。
- 右边 =∫zlogP(x,z∣θ)q(z)⋅q(z)dz−∫zlogP(z∣x,θ)q(z)⋅q(z)dz=ELBO−KL[ q(z) ∣∣ P(z∣x,θ) ]=\int_z \log \frac{P(x,z|\theta)}{q(z)} \cdot q(z) dz-\int_z \log \frac{P(z|x,\theta)}{q(z)} \cdot q(z) dz=ELBO-KL\left [\ q(z)\ ||\ P(z|x,\theta) \ \right ]=∫zlogq(z)P(x,z∣θ)⋅q(z)dz−∫zlogq(z)P(z∣x,θ)⋅q(z)dz=ELBO−KL[ q(z) ∣∣ P(z∣x,θ) ],其中ELBO\text{ELBO}ELBO 为证据下界evidence lower bound;
因此当且仅当 q(z)==P(z∣x,θ)q(z)==P(z|x,\theta)q(z)==P(z∣x,θ) 时,KL[ q(z) ∣∣ P(z∣x,θ) ]=0KL\left [\ q(z)\ ||\ P(z|x,\theta) \ \right ]=0KL[ q(z) ∣∣ P(z∣x,θ) ]=0,logP(x∣θ)=ELBO\log P(x|\theta)=ELBOlogP(x∣θ)=ELBO,此时优化 ELBO\text{ELBO}ELBO 即为优化 logP(x∣θ)\log P(x|\theta)logP(x∣θ),但这种方法并不保证可以找到 logP(x∣θ)\log P(x|\theta)logP(x∣θ) 的全局最大值;
10.4 Jensen不等式导出公式
根据Jensen不等式,f(E(x))⩾E[f(x)]f(E(x))\geqslant E[f(x)]f(E(x))⩾E[f(x)],
logP(x∣θ)=log[∫zP(x,z∣θ)dz]=log[∫zP(x,z∣θ)q(z)⋅q(z)dz]=logEq(z)[P(x,z∣θ)q(z)]⩾Eq(z)[logP(x,z∣θ)q(z)],(10.5)
\begin{aligned} \log P(x|\theta)&=\log \left [\int _z P(x,z|\theta) dz \right ] =\log \left [\int _z \frac{P(x,z|\theta) }{q(z)}\cdot q(z) dz \right ] \\ &=\log E_{q(z)}\left [ \frac{P(x,z|\theta)}{q(z)} \right ]\geqslant E_{q(z)}\left [ \log\frac{P(x,z|\theta)}{q(z)} \right ] ,\tag{10.5}\end{aligned}
logP(x∣θ)=log[∫zP(x,z∣θ)dz]=log[∫zq(z)P(x,z∣θ)⋅q(z)dz]=logEq(z)[q(z)P(x,z∣θ)]⩾Eq(z)[logq(z)P(x,z∣θ)],(10.5)
当且仅当 P(x,z∣θ)q(z)=C\frac{P(x,z|\theta)}{q(z)}=Cq(z)P(x,z∣θ)=C 时等号成立。
左右变换得P(x,z∣θ)C=q(z)\begin{aligned}\frac{P(x,z|\theta)}{C}=q(z) \end{aligned}CP(x,z∣θ)=q(z),两边对 zzz 积分,得 1C∫zP(x,z∣θ)dz=1CP(x∣θ)=∫zq(z)dz=1\frac{1}{C}\int_z P(x,z|\theta) dz=\frac{1}{C}P(x|\theta)=\int_z q(z) dz=1C1∫zP(x,z∣θ)dz=C1P(x∣θ)=∫zq(z)dz=1,即 P(x∣θ)=CP(x|\theta)=CP(x∣θ)=C,因此 q(z)=P(x,z∣θ)P(x∣θ)=P(z∣x,θ)q(z)=\frac{P(x,z|\theta)}{P(x|\theta)}=P(z|x,\theta)q(z)=P(x∣θ)P(x,z∣θ)=P(z∣x,θ),与ELBO+KL散度得到相同结论。
10.5 广义EM
根据上面的推导,EM算法的目标函数 logP(X∣θ)\log P(X|\theta)logP(X∣θ) 等效为 ELBO: Eq(z)[logP(X,Z∣θ)q(z)]=Eq(z)[logP(X,Z∣θ)]−∫zlogq(z)dz\text{ELBO: }E_{q(z)}[\log \frac{P(X,Z|\theta)}{q(z)}]=E_{q(z)}[\log P(X,Z|\theta)]-\int_z\log q(z)dzELBO: Eq(z)[logq(z)P(X,Z∣θ)]=Eq(z)[logP(X,Z∣θ)]−∫zlogq(z)dz,当且仅当 q(z)==P(z∣x,θ)q(z)==P(z|x,\theta)q(z)==P(z∣x,θ) 时,ELBO\text{ELBO}ELBO 达到最大值 ELBO=logP(x∣θ)\text{ELBO}=\log P(x|\theta)ELBO=logP(x∣θ)。
但在某些场景中,后验概率 P(z∣x,θ)P(z|x,\theta)P(z∣x,θ) 无法直接求解,因此可以对 q(z)==P(z∣x,θ)q(z)==P(z|x,\theta)q(z)==P(z∣x,θ) 进行同步迭代,即为广义EM算法:
- E步:q(t+1)(z)=arg maxq ELBOq^{(t+1)}(z)=\underset{q}{\argmax}\ \text{ELBO}q(t+1)(z)=qargmax ELBO;
- M步:θ(t+1)=arg maxθ ELBO\theta^{(t+1)}=\underset{\theta}{\argmax}\ \text{ELBO}θ(t+1)=θargmax ELBO;