当前位置: 首页 > news >正文

windows mamba-ssm环境配置指南

写在前面

本文记录本人在复现mamba-ssm中遇到的所有问题。

创建虚拟环境

无论你的电脑是什么显卡,cuda11还是12,都需要创建虚拟环境。

conda create -n your_env python=3.10.13
conda activate your_env
conda install cudatoolkit==11.8 -c nvidia
conda install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging

安装 triton

triton 在此处下载 https://hf-mirror.com/datasets/ArrayCats/triton-2.0.0-cp310-cp310-win_amd64/tree/main
然后

conda install packaging

下载mamba到本地

mamba-ssm
修改 setup.py 文件

FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"# 修改为:
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"

修改 mamba_ssm/ops/selective_scan_interface.py

def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,return_last_state=False):"""if return_last_state is True, returns (out, last_state)last_state has shape (batch, dim, dstate). Note that the gradient of the last state isnot considered in the backward pass."""return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)def mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,C_proj_bias=None, delta_softplus=True
):return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)# 修改为:
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,return_last_state=False):"""if return_last_state is True, returns (out, last_state)last_state has shape (batch, dim, dstate). Note that the gradient of the last state isnot considered in the backward pass."""return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)def mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,C_proj_bias=None, delta_softplus=True
):return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

最后

pip install .

import selective_scan_cuda 问题

“下载mamba到本地” 中的 代码修改 均已完成,直接注释爆红的 import selective_scan_cuda 即可

# import selective_scan_cuda

其他疑难杂症

cuda available False

此处详细检查自己condo环境的torch是否为cu11.8
在命令行进行下面操作

conda activate your_env
python

粘贴下面命令

import torch
print(torch.__version__)
print(torch.cuda.is_available())

查看是否为False,若为false,卸载torch,重新进行上面步骤。

conda uninstall torch
conda install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118

Numpy问题

RuntimeError: Numpy is not available
AttributeError: module ‘numpy’ has no attribute ‘ndarray’

如果出现Numpy问题,这里没有记录问题图片,读者可以尝试降低numpy为这个版本,可以解决(卸载重装可能会有问题,建议重新安装即可)

conda install numpy=1.23.5
http://www.dtcms.com/a/309952.html

相关文章:

  • 网络层协议IP
  • 运维端口管理闭环:从暴露面测绘到自动化封禁!
  • 【AI问答记录】grafana接收query请求中未携带step参数,后端基于intervalMs和maxDataPoints等参数计算step的逻辑
  • AcWing 897:最长公共子序列 ← 子序列问题(n≤1e3)
  • “数据管理” 一场高风险的游戏
  • 民航领域数据分类分级怎么做?|《民航领域数据分类分级要求》标准解读
  • 第13届蓝桥杯Python青少组中/高级组选拔赛(STEMA)2022年3月13日真题
  • ip去重小脚本
  • uniapp基础 (一)
  • git pull和git fetch的区别
  • Python爬虫实战:研究OpenCV技术构建图像数据处理系统
  • (转)mybatis和hibernate的 缓存区别?
  • (一)React +Ts(vite创建项目)
  • Flask 路由系统:URL 到 Python 函数的映射
  • 嵌入式学习笔记-MCU阶段-DAY10ESP8266模块
  • 第11届蓝桥杯Python青少组中/高级组选拔赛(STEMA)2020年5月30日真题
  • 嵌入式软件 (SW) 设计文件
  • W3D引擎游戏开发----从入门到精通【10】
  • 永洪科技华西地区客户交流活动成功举办!以AI之力锚定增长确定性
  • 视频生成中如何选择GPU或NPU?
  • UE5多人MOBA+GAS 番外篇:同时造成多种类型伤害,以各种属性值的百分比来应用伤害(版本二)
  • 如何理解推理模型
  • 学习:入门uniapp Vue3组合式API版本(17)
  • 2025网络安全指南
  • PyTorch基础——张量计算
  • 考取锅炉司炉工证需要学习哪些专业知识?
  • Altium Designer 22使用笔记(3)---原理图设计
  • Google play上架/更新频繁被拒是什么原因?
  • RabbitMQ 延时队列插件安装与使用详解(基于 Delayed Message Plugin)
  • C++ sort比较规则需要满足严格弱序