mamba学习1
1. Mamba
优点:
- 计算效率高:Mamba的计算复杂度与序列长度呈线性或近线性关系,相比Transformer的二次方计算复杂度,在处理长序列数据时具有显著优势。例如在A100 GPU上,Mamba的计算速度可提升3倍;
- 选择性处理信息:引入选择机制,可根据输入参数化SSM参数,过滤无关信息,保留必要数据,使模型能够专注于对当前任务更重要的部分;
- 硬件感知算法:采用递归扫描而非卷积计算,优化硬件性能,减少GPU中SRAM和DRAM之间的数据传输次数,进一步提升计算效率;
- 高效的训练和推理:平行训练时,使用卷积;推理时,使用递归。
1.1 State Space Model(SSM)
SSM的简化模型如上图,其中输入是x,输出是y,隐藏状态是h。B乘x得到h,之后A乘h更新h,C乘h得到y,D乘x跳跃连接得到y。公式如下:
h
′
(
t
)
=
A
h
(
t
)
+
B
x
(
t
)
.
.
.
(
1
a
)
h'(t)=\mathbf{A}h(t)+\mathbf{B}x(t) ...(1a)
h′(t)=Ah(t)+Bx(t)...(1a)
y
(
t
)
=
C
h
(
t
)
.
.
.
(
1
b
)
y(t)=\mathbf{C}h(t)...(1b)
y(t)=Ch(t)...(1b)
1.2 S4
S4是SSM的进一步优化,SSM到S4的三个步骤:离散化、卷积表示、使用HIPPO算法处理长序列。
1.2.1 离散化
由于SSM使用的都是连续的数据,而我们计算机处理的是数字信号,所以要进行离散化。mamba论文中使用零阶保持算法(zero-order hold)进行离散化。离散之后的公式如下:
h
t
=
A
ˉ
h
t
−
1
+
B
ˉ
x
t
.
.
.
(
2
a
)
h_t=\mathbf{\bar{A}}h_{t-1}+\mathbf{\bar{B}}x_t ...(2a)
ht=Aˉht−1+Bˉxt...(2a)
y
t
=
C
h
t
.
.
.
(
2
b
)
y_t=\mathbf{C}h_t ...(2b)
yt=Cht...(2b)
其中
h
0
=
B
x
0
h_0=\mathbf{B}x_0
h0=Bx0。
1.2.2 卷积表示
利用上述(2a)和(2b)的公式,可以得到下列公式,务必手动推演:
K
ˉ
=
(
C
B
ˉ
,
C
A
ˉ
B
ˉ
,
.
.
.
,
C
A
ˉ
k
B
ˉ
,
.
.
.
)
.
.
.
(
3
a
)
\mathbf{\bar{K}}=(\mathbf{C\bar{B}, C\bar{A}\bar{B}, ..., C\bar{A}^{k}\bar{B}, ...})...(3a)
Kˉ=(CBˉ,CAˉBˉ,...,CAˉkBˉ,...)...(3a)
y
=
x
∗
K
ˉ
.
.
.
(
3
b
)
y=x*\mathbf{\bar{K}}...(3b)
y=x∗Kˉ...(3b)
这就相当于一个卷积公式。所以,它可以像卷积一样并行训练,但是推理时论文中使用递归的方式,可以使得模型更快。
1.2.3 使用HIPPO算法处理长序列
在公式(3a)中我们可以发现一旦token很长,k次方就很大,会导致矩阵相乘的计算量也很大。因此我们可以使用HIPPO算法将矩阵分解成对角阵的形式相乘,计算会方便许多。
1.3 S6
如上图,S6是S4的进一步升级。
x
x
x通过Linear得到
B
,
C
B, C
B,C和
Δ
\Delta
Δ,再通过
A
,
B
A, B
A,B和
Δ
\Delta
Δ得到
A
ˉ
\bar{A}
Aˉ和
B
ˉ
\bar{B}
Bˉ,再通过SSM得到输出
y
y
y。
1.4 Mamba Block
2. 代码使用
2.1 环境
2.2 安装
2.3 使用
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
# This module uses roughly 3 * expand * d_model^2 parameters
d_model=dim, # Model dimension d_model
d_state=16, # SSM state expansion factor
d_conv=4, # Local convolution width
expand=2, # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape