当前位置: 首页 > news >正文

CuTe C++ 简介01,从示例开始

这里先仅仅关注 C++ 层的介绍,python DSL 以后再说。

在 ubuntu 22.04 X64 中,RTX 5080

1. 环境搭建

1.1 安装 cuda

1.2 下载源码

git clone 	https://github.com/NVIDIA/cutlass.git

1.3 编译

mkdir build/
cmake .. -DCUTLASS_NVCC_ARCHS="120" -DCMAKE_BUILD_TYPE="Debug"
make -j20

如果内存不是太大,cpu核心不是太多,那么可以使用 make -j8 等较小的编译线程数量。

查看生成的文件:

2. 调试一个示例

cutlass/examples/cute/tutorial/sgemm_1.cu

2.1 先跑起来

build/examples/cute/tutorial$ ./cute_tutorial_sgemm_1 

2.2 调试一下

启动调试器,启动程序,进入 main 函数:

build/examples/cute/tutorial$ cuda-gdb ./cute_tutorial_sgemm_1
(cuda-gdb) layout src
(cuda-gdb) start

接下来定义了矩阵A、B、C:

其中 初始化 gpu 的函数 cute::device_init(0) 的代码内容如下:

其中,thrust::host_vector<> () 是使用 c++ new表达式申请的内存。

而 thrust::device_vector<> () 是使用 cudaMalloc 申请的显存:

赌一把:

(cuda-gdb) break  cudaMalloc
(cuda-gdb) continue 
(cuda-gdb) bt

可以看到停在了 cudaMalloc() 函数上。

会调用到 gemm_device():

接下来分析一下这个 kernel 的实现:

这个kernel 代码比较长,我们一段段地分析:


template <class ProblemShape, class CtaTiler,class TA, class AStride, class ASmemLayout, class AThreadLayout,class TB, class BStride, class BSmemLayout, class BThreadLayout,class TC, class CStride, class CSmemLayout, class CThreadLayout,class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,TC      * C, CStride dC, CSmemLayout          , CThreadLayout tC,Alpha alpha, Beta beta)
{using namespace cute;// PreconditionsCUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)static_assert(is_static<AThreadLayout>::value);static_assert(is_static<BThreadLayout>::value);static_assert(is_static<CThreadLayout>::value);CUTE_STATIC_ASSERT_V(size(tA) == size(tB));                          // NumThreadsCUTE_STATIC_ASSERT_V(size(tC) == size(tA));                          // NumThreadsCUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{});  // BLK_M / THR_MCUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{});  // BLK_K / THR_KCUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{});  // BLK_N / THR_NCUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{});  // BLK_K / THR_KCUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{});  // BLK_M / THR_MCUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{});  // BLK_N / THR_Nstatic_assert(is_static<ASmemLayout>::value);static_assert(is_static<BSmemLayout>::value);static_assert(is_static<CSmemLayout>::value);CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_MCUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_MCUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_NCUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_NCUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_KCUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_KCUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MKCUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NKCUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN

CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});

为 ProblemShape shape_MNK 类型,

rank

未完待续

。。。。

http://www.dtcms.com/a/354736.html

相关文章:

  • wpf之ListBox
  • 一个客户端直接掉线或断点,服务器怎么快速识别
  • 通过代码认识 CNN:用 PyTorch 实现卷积神经网络识别手写数字
  • audioMAE模型代码分析
  • 循环神经网络——pytorch实现循环神经网络(RNN、GRU、LSTM)
  • 深度学习——卷积神经网络(PyTorch 实现 MNIST 手写数字识别案例)
  • SpringBoot项目使用Liquibase 数据库版本管理
  • Day16_【机器学习—KNN算法】
  • IDA Pro 逆向分析快捷键大全及核心用法详解
  • 【Day 35】Linux-Mysql错误总结
  • 微信小程序对接EdgeX Foundry详细指南
  • 导入文件允许合并表格
  • 上海控安:汽车API安全-风险与防护策略解析
  • Elasticsearch数据迁移方案深度对比:三种方法的优劣分析
  • 领悟8种常见的设计模式
  • 74HC595芯片简析
  • OpenCV 基础操作实战指南:从图像处理到交互控制
  • Ubuntu有限网口无法使用解决方法
  • 企业级AI应用场景的核心特征
  • Fluent Bit针对kafka心跳重连机制详解(上)
  • DA14531(Cortex-M0+)之Wake-up Interrupt Controller (WIC)
  • JumpServer开源堡垒机:统一访问入口 + 安全管控 + 操作审计
  • 从0开始搭建一个前端项目(vue + vite + typescript)
  • 第6篇:链路追踪系统 - 分布式环境下的请求跟踪
  • 【STM32】G030单片机的窗口看门狗
  • 从协作机器人到智能协作机器人:工业革命的下一跳
  • Fluent Bit针对kafka心跳重连机制详解(下)
  • KubeBlocks For MySQL 云原生设计分享
  • Logstash数据迁移之mysql-to-kafka.conf详细配置
  • 卷积神经网络(CNN)搭建详解