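Windows mamba-ssm Environment Setup Guide
Preface
This article records all the problems I ran into while reproducing mamba-ssm.
Create a virtual environment
Whatever GPU you have, and whether your system CUDA is version 11 or 12, create a dedicated virtual environment.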
conda create -n your_env python=3.10.13
conda activate your_env
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
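Before moving on, it can help to confirm that the environment looks the way the commands above intend. The following is a minimal sketch, not part of mamba-ssm: the helper name describe_env is hypothetical, and it only checks the Python version and whether an nvcc executable is visible on PATH.

```python
import shutil
import sys


def describe_env(expected_python=(3, 10)):
    """Collect a few facts about the active environment.

    describe_env is an illustrative helper, not part of any library.
    """
    return {
        # True when the interpreter matches the major.minor used in this guide.
        "python_ok": sys.version_info[:2] == tuple(expected_python),
        "python_version": ".".join(str(p) for p in sys.version_info[:3]),
        # Full path to nvcc if it is on PATH, otherwise None.
        "nvcc_path": shutil.which("nvcc"),
    }


if __name__ == "__main__":
    for key, value in describe_env().items():
        print(f"{key}: {value}")
```

If nvcc_path comes back as None inside the activated environment, the cuda-nvcc install step probably did not take effect.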
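Install triton
Download the Windows triton wheel here: https://hf-mirror.com/datasets/ArrayCats/triton-2.0.0-cp310-cp310-win_amd64/tree/main
Then install the downloaded wheel with pip from the folder that contains it (the file name below is assumed from the name of the download page; use the name of the file you actually downloaded):
pip install triton-2.0.0-cp310-cp310-win_amd64.whl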
Download mamba locally
Download the mamba-ssm source code to your machine.
Modify the setup.py file
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# Change to:
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
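To see what this edit does, the two comparisons can be evaluated in isolation. The sketch below mimics them with a plain dictionary standing in for the environment; build_flags is a hypothetical helper written for this illustration and is not part of setup.py.

```python
def build_flags(env, compare_to):
    """Mimic the two setup.py lines for a given environment mapping.

    build_flags is a hypothetical helper; compare_to is "TRUE" for the
    original lines and "FALSE" for the modified ones.
    """
    force_build = env.get("MAMBA_FORCE_BUILD", "FALSE") == compare_to
    skip_cuda_build = env.get("MAMBA_SKIP_CUDA_BUILD", "FALSE") == compare_to
    return force_build, skip_cuda_build


exported = {"MAMBA_FORCE_BUILD": "TRUE", "MAMBA_SKIP_CUDA_BUILD": "TRUE"}

# Original lines, no variables set: both flags are off.
print(build_flags({}, "TRUE"))        # (False, False)
# Modified lines, no variables set: both flags are on.
print(build_flags({}, "FALSE"))       # (True, True)
# Original lines with both variables set to "TRUE": same result as the edit.
print(build_flags(exported, "TRUE"))  # (True, True)
```

Read this way, the edit turns both flags on by default, so the CUDA extension build is skipped. Setting the two environment variables to TRUE before installing should give the same flags without touching the file, although I have only used the file edit described here.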
Modify mamba_ssm/ops/selective_scan_interface.py
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

# Change to:
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                           out_proj_weight, out_proj_bias,
                           A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
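The change routes both functions to the reference implementations (selective_scan_ref and mamba_inner_ref) instead of the compiled CUDA kernels, trading speed for not needing the extension. As a rough illustration of the kind of step-by-step recurrence a reference scan evaluates, here is a dependency-free sketch for a single channel. It is a simplification under my own assumptions (no batching, no softplus on delta, no z gating), and selective_scan_1d is a hypothetical name, not the mamba_ssm code.

```python
import math


def selective_scan_1d(u, delta, A, B, C, D=0.0):
    """Sequential scan for one channel with an N-dimensional state.

    u, delta: lists of length L (input and step size per time step)
    A: list of length N (diagonal of the state matrix)
    B, C: lists of L lists of length N (input-dependent projections)
    D: scalar skip connection
    Illustrative only; not the mamba_ssm implementation.
    """
    n = len(A)
    h = [0.0] * n
    out = []
    for t in range(len(u)):
        for i in range(n):
            # State update: h <- exp(delta * A) * h + delta * B * u
            h[i] = math.exp(delta[t] * A[i]) * h[i] + delta[t] * B[t][i] * u[t]
        # Readout: y = C . h + D * u
        out.append(sum(C[t][i] * h[i] for i in range(n)) + D * u[t])
    return out


# With A = 0 the state never decays, so the output is a running sum of u.
print(selective_scan_1d([1.0, 2.0, 3.0], [1.0] * 3, [0.0], [[1.0]] * 3, [[1.0]] * 3))
# [1.0, 3.0, 6.0]
```

A loop over time steps like this is simple but slow, which is why the reference path runs noticeably slower than the fused kernel on long sequences.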
Finally, run the following from the root of the mamba source folder:
pip install .
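The import selective_scan_cuda problem
If all the code changes in "Download mamba locally" have been made, simply comment out the import selective_scan_cuda line that the editor highlights in red: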
# import selective_scan_cuda
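A softer alternative to commenting the line out is a guarded import, so the same file still works on a machine where the extension was built. This is only a sketch: optional_import is a hypothetical helper, and I have not tested this variant inside mamba-ssm itself.

```python
import importlib


def optional_import(name):
    """Return the imported module, or None if it cannot be imported.

    optional_import is a hypothetical helper, not part of mamba-ssm.
    """
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# On a build without the compiled extension this is None.
selective_scan_cuda = optional_import("selective_scan_cuda")
print("CUDA kernel available:", selective_scan_cuda is not None)
```

Any code that then uses selective_scan_cuda would need to check for None first, which is not an issue here because the modified functions above no longer call it.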
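Other troubleshooting
CUDA available is False
Check carefully that the torch installed in your conda environment is the cu118 (CUDA 11.8) build.
Run the following on the command line: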
conda activate your_env
python
Then paste the following into the Python prompt:
import torch
print(torch.__version__)
print(torch.cuda.is_available())
Check whether the last line prints False. If it does, uninstall torch and repeat the installation step above.
pip uninstall torch torchvision torchaudio
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
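The check above can be put into a small script that also reports which build of torch is installed. This is a sketch under the assumption that wheels from the PyTorch index carry a suffix such as +cu118 or +cpu in their version string; cuda_tag is a hypothetical helper.

```python
def cuda_tag(version):
    """Return the build tag after "+" in a torch version string, e.g. "cu118".

    Returns None when the string has no "+tag" suffix. Hypothetical helper.
    """
    _, sep, tag = version.partition("+")
    return tag if sep else None


try:
    import torch
except ImportError:
    print("torch is not installed in this environment")
else:
    print("torch version:", torch.__version__)
    print("build tag:", cuda_tag(torch.__version__))
    # torch.version.cuda is None for CPU-only builds.
    print("compiled against CUDA:", torch.version.cuda)
    print("CUDA available:", torch.cuda.is_available())
```

A build tag of cpu, or a compiled CUDA version of None, points to a CPU-only wheel, in which case reinstalling from the cu118 index is the fix.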
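NumPy problems
RuntimeError: Numpy is not available
AttributeError: module 'numpy' has no attribute 'ndarray'
If you hit one of these NumPy errors (I did not save a screenshot of them), try switching numpy to the version below, which resolves it. Uninstalling numpy first may cause problems, so I recommend simply installing the pinned version over the existing one.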
conda install numpy=1.23.5
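To confirm which numpy version ended up in the environment without importing numpy (which may itself fail when the install is broken), the package metadata can be queried instead. A minimal sketch; installed_version and PINNED_NUMPY are names made up for this example.

```python
from importlib import metadata

PINNED_NUMPY = "1.23.5"  # the version used in this guide


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


found = installed_version("numpy")
if found is None:
    print("numpy is not installed")
elif found == PINNED_NUMPY:
    print("numpy matches the pinned version", PINNED_NUMPY)
else:
    print(f"numpy {found} is installed; this guide used {PINNED_NUMPY}")
```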