Setting up the mamba and mamba2 environment
Commands for installing mamba and mamba2:
conda create -n mamba_test python=3.10
conda activate mamba_test
conda install cudatoolkit=11.8 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/linux-64/
pip install mamba_ssm-2.2.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install causal_conv1d-1.4.0+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install triton==2.1.0
pip install numpy==1.22.4
Download locations for the corresponding whl files:
mamba_ssm download
causal_conv1d download
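The two wheel filenames used above encode the build they were compiled for, so it can help to read those fields out before choosing a file. Below is a minimal sketch, assuming the local version tag always looks like `cu118torch2.0cxx11abiFALSE`; that layout is inferred from the two filenames in this post, and `parse_wheel_name` is a hypothetical helper, not part of either package.

```python
import re

# Layout of the local version tag, inferred from the filenames in this post
# (an assumption, not an official specification).
_BUILD_TAG = re.compile(
    r"cu(?P<cuda>\d+)torch(?P<torch>[\d.]+)cxx11abi(?P<abi>TRUE|FALSE)"
)


def parse_wheel_name(filename):
    """Split a wheel filename into the fields that must match the environment."""
    stem = filename[:-len(".whl")] if filename.endswith(".whl") else filename
    # Assumes the five-part form: name-version-pythontag-abitag-platformtag
    name, version, py_tag, abi_tag, platform = stem.split("-")
    base_version, _, local = version.partition("+")
    info = {
        "package": name,
        "version": base_version,
        "python": py_tag,
        "platform": platform,
    }
    match = _BUILD_TAG.fullmatch(local)
    if match:
        info["cuda"] = "cu" + match.group("cuda")        # e.g. "cu118"
        info["torch"] = match.group("torch")             # e.g. "2.0"
        info["cxx11abi"] = match.group("abi") == "TRUE"  # False for these wheels
    return info


for whl in (
    "mamba_ssm-2.2.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
    "causal_conv1d-1.4.0+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
):
    print(parse_wheel_name(whl))
```

For the files above, the fields should line up with the conda environment created earlier: `cp310` with python=3.10 and `cu118` with cudatoolkit=11.8.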
Working test code for mamba and mamba2:
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
print("Mamba result", y.shape)
assert y.shape == x.shape
import torch
from mamba_ssm import Mamba2
batch, length, dim = 2, 64, 512
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    # make sure d_model * expand / headdim = multiple of 8
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    headdim=64,   # default 64
).to("cuda")
y = model(x)
print("Mamba2 result", y.shape)
assert y.shape == x.shape
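The comments in the test code give two rules of thumb: roughly 3 * expand * d_model^2 parameters per block, and d_model * expand / headdim being a multiple of 8 for Mamba2. The sketch below just does that arithmetic for the values used above; `estimate_params` and `heads_ok` are hypothetical helper names, and the formula is only the rough estimate quoted in the comments, not an exact count.

```python
def estimate_params(d_model, expand):
    """Rough parameter estimate from the code comment: 3 * expand * d_model^2."""
    return 3 * expand * d_model ** 2


def heads_ok(d_model, expand, headdim):
    """Check the Mamba2 constraint: d_model * expand / headdim is a multiple of 8."""
    d_inner = d_model * expand
    return d_inner % headdim == 0 and (d_inner // headdim) % 8 == 0


print(estimate_params(16, 2))    # Mamba test:  3 * 2 * 16^2  = 1536
print(estimate_params(512, 2))   # Mamba2 test: 3 * 2 * 512^2 = 1572864
print(heads_ok(512, 2, 64))      # True: 512 * 2 / 64 = 16 heads
print(heads_ok(16, 2, 64))       # False: 16 * 2 = 32 is smaller than headdim
```

With headdim=64, the dim=16 setting from the Mamba test would not satisfy the Mamba2 constraint, whereas dim=512 does.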
Debugging references:
Mamba-2 Error: 'NoneType' object has no attribute 'causal_conv1d_fwd'
mamba_ssm and causal-conv1d installation tutorial
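One plausible reading of the 'NoneType' error above is that the compiled causal_conv1d extension failed to import and the name that should hold it was left as None. This is an assumption based on the message, not something the post confirms. The sketch below tries each import directly so the underlying error becomes visible; `check_import` is a hypothetical helper, and the module name `causal_conv1d_cuda` for the compiled extension is also an assumption.

```python
import importlib


def check_import(module_name):
    """Try to import a module and report the underlying error instead of hiding it."""
    try:
        importlib.import_module(module_name)
    except Exception as exc:  # ImportError, or e.g. OSError from a missing shared library
        return False, f"{type(exc).__name__}: {exc}"
    return True, "ok"


if __name__ == "__main__":
    # torch goes first because the compiled extensions depend on its libraries.
    for name in ("torch", "causal_conv1d_cuda", "causal_conv1d", "mamba_ssm"):
        ok, detail = check_import(name)
        print(f"{name:20s} {'OK' if ok else 'FAILED'}  {detail}")
```

If one of the imports fails with an undefined-symbol or missing-library message, the wheel most likely does not match the installed torch or CUDA version, and picking a whl whose filename matches the environment is the first thing to try.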