RNN-Gauss / RNN-GMM 模型的结构
RNN-Gauss 和 RNN-GMM 是专门为处理连续值时间序列数据(特别是具有复杂不确定性或多种模式的数据)而设计的概率生成模型。
它们本质上是将 RNN(作为强大的时间依赖关系捕捉器)与 高斯分布 或 高斯混合模型(作为灵活的输出分布建模器)相结合。
下图清晰地展示了这两种模型的核心架构与区别:
一、核心思想
传统的RNN在输出层通常使用一个简单的线性层或Softmax层,这限制了它们对复杂数据分布的建模能力。RNN-Gauss/GMM的核心创新在于:
- RNN(LSTM或GRU):处理序列输入,更新隐藏状态 hth_tht,捕捉历史信息的依赖关系。
- 参数输出层:使用一个或多个全连接层,将隐藏状态 hth_tht 映射到概率分布的参数(如高斯分布的 μ\muμ 和 σ\sigmaσ)。
- 概率输出:在每一个时间步 ttt,模型不直接预测一个具体的值 yty_tyt,而是预测一个概率分布 P(yt∣x1:t)P(y_t | x_{1:t})P(yt∣x1:t)。最终的预测值从这个分布中采样或取其期望(如均值)得到。
这种结构使得模型不仅能给出预测值,还能量化预测的不确定性(通过分布的方差)。
二、RNN-Gauss 模型结构
RNN-Gauss 假设在给定历史信息的情况下,下一个时间步的目标值 yty_tyt 服从一个高斯分布。
1. 网络结构
- 输入:时间序列数据 x1,x2,...,xTx_1, x_2, ..., x_Tx1,x2,...,xT。
- RNN层:可以是LSTM或GRU单元,处理输入序列,输出隐藏状态 hth_tht。
ht=RNN(xt,ht−1)h_t = \text{RNN}(x_t, h_{t-1})ht=RNN(xt,ht−1) - 参数输出层:一个全连接层,将隐藏状态 hth_tht 映射到高斯分布的两个参数:
μt=Wμht+bμ\mu_t = W_{\mu}h_t + b_{\mu}μt=Wμht+bμ
σt=softplus(Wσht+bσ)=log(1+exp(Wσht+bσ))\sigma_t = \text{softplus}(W_{\sigma}h_t + b_{\sigma}) = \log(1 + \exp(W_{\sigma}h_t + b_{\sigma}))σt=softplus(Wσht+bσ)=log(1+exp(Wσht+bσ))- 注意:使用
softplus
激活函数是为了保证标准差 σt\sigma_tσt 为正数。
- 注意:使用
2. 输出分布
在每一个时间步 ttt,模型定义了一个条件概率分布:
P(yt∣x1:t)=N(μt,σt2)P(y_t | x_{1:t}) = \mathcal{N}(\mu_t, \sigma_t^2)P(yt∣x1:t)=N(μt,σt2)
即 yty_tyt 服从均值为 μt\mu_tμt,方差为 σt2\sigma_t^2σt2 的高斯分布。
3. 损失函数:负对数似然
模型通过最大化观测数据的对数似然来训练,等价于最小化负对数似然损失。
L=−∑t=1TlogP(yt∣x1:t)=∑t=1T((yt−μt)22σt2+12log(2πσt2))\mathcal{L} = -\sum_{t=1}^T \log P(y_t | x_{1:t}) = \sum_{t=1}^T \left( \frac{(y_t - \mu_t)^2}{2\sigma_t^2} + \frac{1}{2}\log(2\pi\sigma_t^2) \right)L=−∑t=1TlogP(yt∣x1:t)=∑t=1T(2σt2(yt−μt)2+21log(2πσt2))
这个损失函数会同时优化均值 μt\mu_tμt 和方差 σt\sigma_tσt。当不确定性高时(σt\sigma_tσt 大),模型会对预测误差 (yt−μt)2(y_t - \mu_t)^2(yt−μt)2 的惩罚变小。
三、RNN-GMM 模型结构
RNN-GMM 是RNN-Gauss的扩展,它假设下一个时间步的目标值 yty_tyt 服从一个高斯混合模型。这适用于更复杂的情况,即数据在某个时间点可能存在多种可能的状态或模式。
1. 网络结构
- 输入和RNN层:与RNN-Gauss相同。
- 参数输出层:一个更大的全连接层,将隐藏状态 hth_tht 映射到GMM的所有参数。假设混合成分有 KKK 个:
- 混合权重 πtk\pi_t^kπtk:KKK 个值,使用Softmax确保它们和为1。
πt=softmax(Wπht+bπ)\boldsymbol{\pi}_t = \text{softmax}(W_{\pi}h_t + b_{\pi})πt=softmax(Wπht+bπ) - 均值 μtk\mu_t^kμtk:KKK 个值。
μtk=Wμkht+bμk\mu_t^k = W_{\mu^k}h_t + b_{\mu^k}μtk=Wμkht+bμk - 标准差 σtk\sigma_t^kσtk:KKK 个值,同样用softplus激活。
σtk=softplus(Wσkht+bσk)\sigma_t^k = \text{softplus}(W_{\sigma^k}h_t + b_{\sigma^k})σtk=softplus(Wσkht+bσk)
- 混合权重 πtk\pi_t^kπtk:KKK 个值,使用Softmax确保它们和为1。
2. 输出分布
在每一个时间步 ttt,模型定义的条件概率分布是多个高斯分布的加权和:
P(yt∣x1:t)=∑k=1Kπtk⋅N(μtk,(σtk)2)P(y_t | x_{1:t}) = \sum_{k=1}^K \pi_t^k \cdot \mathcal{N}(\mu_t^k, (\sigma_t^k)^2)P(yt∣x1:t)=∑k=1Kπtk⋅N(μtk,(σtk)2)
3. 损失函数:GMM负对数似然
L=−∑t=1Tlog(∑k=1Kπtk⋅N(yt;μtk,(σtk)2))\mathcal{L} = -\sum_{t=1}^T \log \left( \sum_{k=1}^K \pi_t^k \cdot \mathcal{N}(y_t; \mu_t^k, (\sigma_t^k)^2) \right)L=−∑t=1Tlog(∑k=1Kπtk⋅N(yt;μtk,(σtk)2))
这个损失函数会同时优化所有混合成分的权重、均值和方差。
四、应用场景
- 金融时间序列预测:股票价格、汇率等。市场未来走势具有内在不确定性,RNN-Gauss可以预测价格区间,RNN-GMM甚至可以预测“大涨”、“大跌”、“震荡”等多种可能模式。
- 语音信号生成:在TTS系统中,一个音素对应的声学特征可能有多种实现方式,GMM可以很好地建模这种多模态分布。
- 机器人运动规划:预测轨迹时,可能存在多条可行的路径。
- 气象预测:预测温度、风速等,天然需要提供不确定性估计。
总结
特征 | RNN-Gauss | RNN-GMM |
---|---|---|
输出分布 | 单峰高斯分布 | 多峰高斯混合模型 |
不确定性 | 量化不确定性(方差) | 量化不确定性,并能表示多模态 |
参数量 | 较少(每步输出2个参数) | 较多(每步输出 3K3K3K 个参数) |
适用场景 | 不确定性预测 | 多模态、多可能性的预测 |
这两种模型代表了将深度学习的表示学习能力与概率模型的严谨性相结合的成功范例,是处理复杂时间序列问题的强大工具。