当前位置: 首页 > news >正文

Tensor Core的MMA与WMMA

文章目录

  • WMMA与MMA区别
    • 使用 WMMA 指令
    • 使用 MMA 指令
  • Tensor Core详解之mma
  • 参考

WMMA与MMA区别

WMMA:Warp Matrix Multiply Accumulate

MMA:Matrix Multiply Accumulate

使用 WMMA 指令

WMMA 是一种较高级别的编程接口,允许开发者以“矩阵片段”为单位进行操作,而不需要关心线程间数据分布的细节。

步骤:

  • 1.加载矩阵(wmma.load)
    使用 wmma.load 操作将矩阵 A、B 和 C 从内存加载到寄存器中。
    每个线程的寄存器中会保存矩阵的一个“片段”(fragment)。
  • 2.执行矩阵乘加运算(wmma.mma)
    使用 wmma.mma 对已加载的矩阵片段执行乘加运算:D=A×B+C
    每个线程的寄存器中会保存结果矩阵 D 的一个片段。
  • 3.存储结果(wmma.store)
    使用 wmma.store 将结果矩阵 D 的片段写回内存。
    也可以将 D 作为 wmma.mma 的参数 C,实现原地更新。

特点:

  • wmma.load 和 wmma.store 会自动处理矩阵元素在线程之间的分布(“隐式处理矩阵元素的组内方式”)。
  • 开发者无需手动管理数据在 warp 内的分布。

使用 MMA 指令

MMA 是一种更底层的接口,提供更高的灵活性,但需要开发者显式管理线程间的数据分布。

步骤:

  • 与 WMMA 类似,也需要 warp 内所有线程协同执行。
  • 不同之处在于:在调用 mma 操作之前,必须显式地将矩阵元素分配到不同线程中
  • MMA 支持更多种类的矩阵格式,包括结构化稀疏矩阵(如“缩流矩阵”,可能是指压缩稀疏行/列格式)。
  • 当矩阵 A 是结构化稀疏矩阵时,可以使用专门的 MMA 变体来提高效率。

特点:

  • 更灵活,支持更多矩阵类型和稀疏格式。
  • 需要开发者手动处理数据在 warp 内的分布。
对比项WMMA (Warp Matrix Multiply Accumulate API)MMA (Matrix Multiply Accumulate Instruction)
定义层级CUDA C++ 提供的 高级 API,封装在 nvcuda::wmma namespace 中GPU 底层 PTX 指令mma.sync 等),直接由硬件执行
使用难度接口类似 C++ 模板库,开发者易于调用(如 wmma::load_matrix_sync / wmma::mma_sync需要写 PTX 汇编或 inline PTX,开发复杂度高
适用对象给普通 CUDA 程序员使用,降低学习成本给编译器后端和高性能库开发者(cuBLAS、CUTLASS)使用
可控粒度支持固定 tile 大小 (16x16x16, 8x32x16, 32x8x16),由 API 限制灵活,直接暴露硬件支持的各种矩阵 tile 形状
可移植性跨 GPU 架构相对稳定,NVIDIA 保证 API 向后兼容随 GPU 架构变化(Volta → Turing → Ampere → Hopper),指令格式可能不同
性能性能接近底层 MMA(API 最终会编译成 MMA 指令)理论上最佳性能,但需要开发者手工调度、对齐、bank conflict 处理
典型使用场景在 CUDA kernel 里写 warp 级 GEMM 或 TensorCore 加速计算实现 cuBLAS、CUTLASS、深度学习框架内核时用
开发成本简单,几行代码即可实现 Tensor Core 计算高,需要理解 GPU 架构、寄存器布局和 PTX 编写

WMMA = 面向开发者的 高级封装 API,写起来友好,效率接近硬件极限;
MMA = 底层 硬件指令,灵活且性能最高,但写起来很复杂,通常由编译器和高性能库使用。

Tensor Core详解之mma

参考:

  • Toy HGEMMM Library using Tensor Cores with MMA/WMMA/CuTe

HGEMM CUDA Kernels in Toy-HGEMM Library

void hgemm_naive_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_sliced_k_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_pack(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x8_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);

为啥mma版本的naive不用m16n16k16?
在这里插入图片描述

以hgemm_mma.cu中的hgemm_mma_m16n8k16_naive_kernel函数为例

// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major.
template <const int MMA_M = 16, const int MMA_N = 8, const int MMA_K = 16>
__global__ void hgemm_mma_m16n8k16_naive_kernel(half *A, half *B, half *C,int M, int N, int K) {const int bx = blockIdx.x;const int by = blockIdx.y;const int NUM_K_TILES = div_ceil(K, MMA_K);constexpr int BM = MMA_M; // 16constexpr int BN = MMA_N; // 8constexpr int BK = MMA_K; // 16__shared__ half s_a[MMA_M][MMA_K]; // 16x16__shared__ half s_b[MMA_K][MMA_N]; // 16x8__shared__ half s_c[MMA_M][MMA_N]; // 16x8const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within blockconst int lane_id = tid % WARP_SIZE;                    // 0~31// s_a[16][16], 每行16,每线程load 8,需要2线程,共16行,需2x16=32线程const int load_smem_a_m = tid / 2;       // row 0~15const int load_smem_a_k = (tid % 2) * 8; // col 0,8// s_b[16][8], 每行8,每线程load// 8,需要1线程,共16行,需16线程,只需一半线程加载const int load_smem_b_k = tid; // row 0~31, but only use 0~15const int load_smem_b_n = 0;   // col 0const int load_gmem_a_m = by * BM + load_smem_a_m; // global mconst int load_gmem_b_n = bx * BN + load_smem_b_n; // global nif (load_gmem_a_m >= M && load_gmem_b_n >= N)return;uint32_t RC[2] = {0, 0};#pragma unrollfor (int k = 0; k < NUM_K_TILES; ++k) {// gmem_a -> smem_a,A从global memory搬运到shared memoryint load_gmem_a_k = k * BK + load_smem_a_k; // global col of aint load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) =(LDST128BITS(A[load_gmem_a_addr]));// gmem_b -> smem_b:B从global memory搬运到shared memoryif (lane_id < MMA_K) {int load_gmem_b_k = k * MMA_K + load_smem_b_k; // global row of bint load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) =(LDST128BITS(B[load_gmem_b_addr]));}__syncthreads();uint32_t RA[4];uint32_t RB[2];// ldmatrix for s_a, ldmatrix.trans for s_b. 从shared memory搬运到寄存器// s_a: (0,1)*8 -> 0,8 -> [(0~15),(0,8)]uint32_t load_smem_a_ptr =__cvta_generic_to_shared(&s_a[lane_id % 16][(lane_id / 16) * 8]);LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], load_smem_a_ptr);uint32_t load_smem_b_ptr = __cvta_generic_to_shared(&s_b[lane_id % 16][0]);LDMATRIX_X2_T(RB[0], RB[1], load_smem_b_ptr);HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0],RC[1]);__syncthreads();}// s_c[16][8],// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html// #matrix-fragments-for-mma-m16n8k16-with-floating-point-type// [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16]LDST32BITS(s_c[lane_id / 4][(lane_id % 4) * 2]) = LDST32BITS(RC[0]);LDST32BITS(s_c[lane_id / 4 + 8][(lane_id % 4) * 2]) = LDST32BITS(RC[1]);__syncthreads();// store s_c[16][8]if (lane_id < MMA_M) {// store 128 bits per memory issue.int store_gmem_c_m = by * BM + lane_id;int store_gmem_c_n = bx * BN;int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;LDST128BITS(C[store_gmem_c_addr]) = (LDST128BITS(s_c[lane_id][0]));}
}

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

参考

  • 【CUDA进阶】MMA分析Bank Conflict与Swizzle(已完结)
  • 【CUDA进阶】Tensor Core实战教程(已完结)
  • 【CUDA进阶】Tensor Core实战教程(上)
  • 【CUDA进阶】Tensor Core实战教程(下)
http://www.dtcms.com/a/428682.html

相关文章:

  • 大模型计算事实标准--CUDA
  • 网站导航页面模板网络与新媒体就业方向及前景
  • 公司网站 备案南通市建设监理协会网站
  • 洋县住房和城乡建设管理局网站网站中宣传彩页怎么做的
  • 网站群建设进展情况汇报郑州网站开发技术
  • C语言数据结构-堆
  • 网站建设开头营销型网站推广公司
  • 公考刷题资源合集
  • 网站建设与管理专业好找工作吗通辽网站建设
  • QCustomPlot 特定图表类型实战
  • python 将关键数据标注在png图片里
  • python加速方法 对比 numba numb.cuda triton pycuda cupy
  • 常州天狼网站建设二手物品交换网站建设
  • 关于Java的几个小问题
  • 青岛手机网站建设手工制作月饼
  • 国外的哪个网站可以做跳转青岛网站推广途径
  • 湖北省和住房建设厅官方网站山东泰山新闻
  • 营销网站定制公司台州网站设计哪家好
  • C++——基础
  • 回顾首尔 KBW 2025,Sui 的创新与联结周
  • 2025CCPC郑州邀请赛暨河南省赛 B. 随机栈 II 题解
  • 珠海网站建设公商城二次开发
  • 合肥市蜀山区做个网站多少钱电子商务网页制作是什么
  • 做网站推广需要多少费用one dirve做网站
  • 京东网站的建设与发展前景上海闵行中心医院
  • 无锡网站建设要求手机设计软件app推荐
  • 网站开发e r图无极网页游戏
  • 网站建设与功能模块wordpress appkey 插件
  • 鸿鹄网站建设网站的建设费用
  • 硅胶东莞网站建设公司注册后怎么做网站