从SSM到S4
参考视频:Mamba 超超超详细解说 | 1、对 SSM 的透彻理解_哔哩哔哩_bilibili
以下内容为上述视频的学习记录,详情可看视频。
一、SSM:State Space Model
- 状态空间模型(State Space Model,SSM)是一种数学框架,用于描述动态系统的行为。
通过描述系统的状态变量及其动态变化规律,建模系统的状态随时间的演变以及观测数据的变化。
1.1 组成部分
- 状态方程(State Equation):描述系统的内部状态如何随时间演变。
- 输出方程(Output Equation):描述外部可观测的输出如何与系统状态相关联。
1.2 什么是状态State
状态是系统当前所处的条件或信息,用以确定其未来行为。
1.3 连续SSM--Continuous State Space Model
时间是连续的,状态随时间的变化用微分方程表示。
状态方程:
:状态变量的时间导数,表示系统随时间的变化率。
:状态变量,表示系统在时刻 t 的状态。
:输入变量,表示系统的外部控制输入。
:状态转移矩阵,描述状态随时间变化的动态关系。
:控制矩阵,描述输入变量对状态变化的影响。
输出方程:
:输出变量,表示系统的观测值。
:输出矩阵,描述状态对系统输出的影响。
:直连矩阵,描述输入变量直接对输出的影响。
1.4 离散SSM--Discrete State Space Model
时间是离散的,状态随时间的变化用差分方程表示。
状态方程:
:系统在离散时刻 k的状态。
:系统在下一时刻 k+1的状态。
:在时刻 k 的输入变量。
:状态转移矩阵,描述状态从 k到 k+1 的转移关系。
:控制矩阵。
输出方程:
:系统在时刻 k 的输出。
:输出矩阵,描述状态对系统输出的影响。
:直连矩阵,描述输入对输出的直接影响。
1.5 连续SSM和离散SSM的区别
1.6 SSM举例(理解系统与状态)
1.6.1 离散SSM--养鱼场例子
1.6.2 连续SSM--弹簧系统例子
1.7 SSM的离散化
1.7.1 为什么要离散化
- 对于连续的系统,计算机没办法处理,计算机处理数据和执行算法时通常使用离散的时间步长。离散化将连续时间系统转换为离散时间系统,使其能够在计算机上进行数值模拟和计算。
- 通过 连续SSM 的状态方程,我们可以知道任意时刻:输入、状态、状态微分(状态对时间的导数) 的关系,但没办法根据上一时刻的状态推测下一时刻的状态。离散化后,可以将系统表示为 状态的递推公式,可逐步递推系统状态。
1.7.2 离散化方法
离散化即使用方法近似公式(3)中的
1.7.2.1 前向欧拉法
1.7.2.2 后向欧拉法
1.7.2.3 梯形法
1.7.2.4 零阶保持 Zero-Order Hold
零阶保持(Zero-Order Hold,ZOH)用于将离散时间信号转换为连续时间信号。这种方法假设输入信号在每个采样周期内保持恒定不变,即在每个采样点之后,直到下一个采样点到来之前,信号的值不再变化。
1.7.3 弹簧系统离散化代码示例
1.7.3.1 定义状态空间模型SSM
函数 example_mass(k, b, m)
def example_mass(k, b, m):
A = np.array([[0, 1], [-k / m, -b / m]])
B = np.array([[0], [1.0 / m]])
C = np.array([[1.0, 0]])
return A, B, C
- 输入参数:弹簧常数
k
,阻尼系数b
,质量m
。 - 矩阵定义:
A
:状态转移矩阵,描述系统的动态性质。B
:控制矩阵,描述输入对状态的影响。C
:输出矩阵,描述状态如何影响输出
函数 example_force(t)
@partial(np.vectorize, signature='()->()')
def example_force(t):
x = np.sin(10 * t)
return x * (x > 0.5)
- 描述:定义一个时间函数
u(t)
,该函数为正弦函数的变体,仅在sin(10 * t)
大于 0.5 时有效。 - 作用:为系统提供外部输入
u(t)
。
1.7.3.2 离散化
函数 discretize(A, B, C, step)
def discretize(A, B, C, step):
I = np.eye(A.shape[0])
BL = inv(I - (step / 2.0) * A)
Ab = (I + (step / 2.0) * A) @ BL
Bb = (BL @ step) @ B
return Ab, Bb, C
- 输入参数:连续时间矩阵
A
、B
、C
和离散时间步长step
。 - 输出:离散时间的矩阵
Ab
、Bb
和C
。 - 方法:使用双线性变换(梯形积分法)进行离散化,这种方法可以提高离散化的准确性。
Ab
:离散化后的状态转移矩阵。Bb
:离散化后的控制矩阵。
1.7.3.3 运行状态空间模型
函数 run_SSM(A, B, C, u)
def run_SSM(A, B, C, u):
L = u.shape[0]
N = A.shape[0]
Ab, Bb, Cb = discretize(A, B, C, step=1.0 / L)
return scan_SSM(Ab, Bb, Cb, u[:, np.newaxis], np.zeros((N,)))[1]
- 输入参数:状态空间矩阵
A
、B
、C
和输入u
。 - 主要步骤:
- 计算离散化的矩阵
Ab
、Bb
和Cb
。 - 调用函数
scan_SSM
进行递归计算,返回系统的输出。
- 计算离散化的矩阵
辅助函数 scan_SSM(Ab, Bb, Cb, u, x0)
def scan_SSM(Ab, Bb, Cb, u, x0):
def step(x_k, u_k):
x_k = Ab @ x_k + Bb @ u_k
y_k = Cb @ x_k
return x_k, y_k
return jax.lax.scan(step, x0, u)
- 描述:使用
jax.lax.scan
实现递归计算,模拟离散时间状态空间系统的动态行为。 step
函数:在每个时间步长上更新状态x_k
和计算输出y_k
。
1.7.3.4 运行示例
def example_ssm():
ssm = example_mass(k=40, b=5, m=1)
# Samples of u(t)
L = 100
step = 1.0 / L
ks = np.arange(L)
u = example_force(ks * step)
# Approximation of y(t)
y = run_SSM(*ssm, u)
# Plotting
import matplotlib.pyplot as plt
import seaborn
from celluloid import Camera
seaborn.set_context("paper")
fig, (ax1, ax2, ax3) = plt.subplots(3)
camera = Camera(fig)
ax1.set_title("Force $u_k$")
ax2.set_title("Position $y_k$")
ax3.set_title("Object")
ax1.set_xticks([], [])
ax2.set_xticks([], [])
ax3.set_xticks([], [])
# Animate plot over time
for k in range(0, L, 2):
ax1.plot(ks[:k], u[:k], color="red")
ax2.plot(ks[:k], y[:k], color="blue")
ax3.boxplot(
[[y[k], -0.04, y[k], 0], [y[k], 0, y[k], 0.04]],
showcaps=False,
whis=False,
vert=False,
widths=0.1,
)
camera.snap()
anim = camera.animate()
anim.save("images/line.gif", dpi=150, writer="imagemagick")
if __name__ == "__main__":
example_ssm()
二、S4--Structured State Space for Sequences
Structured State Space for Sequences (S4)模型在训练和推理使用了不同形式,并且设计了Hippo矩阵作为SSM方程中的矩阵A。
由上述内容可知,离散SSM的公式如下:
2.1 SSM的RNN表示
按照timesteps展开表示:
如下图,可发现SSM和RNN的表达形式基本一致:
2.2 SSM的Convolution表示
因为文字序列是一维的,它的一维卷积表示如下:
卷积形式只有一个表达式:
2.3 RNN表示和Convolution表示的使用
2.4 Hippo矩阵
2.5 S4的参数化
NPLR(Normal Plus LowRank):正规矩阵(Normal Matrix) + 低秩矩阵(LowRank Matrix)