当前位置: 首页 > news >正文

【CUDA进阶】Tensor Core实战教程(下)

目录

    • 前言
    • 1. WMMA(Warp Matrix Multiply Accumulate)
    • 2. hgemm_v1_wmma_m16n16k16_naive_kernel
    • 3. hgemm_v2_wmma_m16n16k16_mma4x2_kernel
    • 4. hgemm_v3_wmma_m16n16k16_mma4x2_warp2x4_kernel
    • 5. hgemm_v4_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_kernel
    • 结语
    • 下载链接
    • 参考

前言

学习 UP 主 比飞鸟贵重的多_HKL 的 【CUDA进阶】Tensor Core实战教程(已完结) 视频,记录下个人学习笔记,仅供自己参考😄

refer 1:【CUDA进阶】Tensor Core实战教程(已完结)

refer 2:https://github.com/xlite-dev/LeetCUDA

refer 3:https://github.com/Bruce-Lee-LY/cuda_hgemm

refer 4:https://chatgpt.com

1. WMMA(Warp Matrix Multiply Accumulate)

在上篇文章中我们了解了 Tensor Core 的基本原理,并学习了利用 cuBLAS 库的 cublasGemmEx 接口调用 Tensor Core 来实现矩阵乘法。此外在上篇文章中我们还详细分析了 cublasGemmEx 接口的参数,一些常见的概念例如主序、主维、转置等我们都有提及

虽然 cuBLAS 已经把矩阵乘法运算(GEMM)的维度、数据布局、调度策略等都封装好了,调用也简单,但它也因此 “黑盒” 化,也就是我们无法在其内部插入一些其它的运算

例如在我们讲 Flash Attention 的时候,有提到它主要是将原始 Attention 中的 Softmax、Mask、MatMul 等算子 fusion 融合在一起,不对中间结果缓存,减少 HBM 的访问。而 cuBLAS 提供的接口只能做纯粹的矩阵乘,不支持中间插入 Softmax、Mask 等步骤,无法实现 Flash Attention 中的算子融合优化

所以我们就需要学习一些更底层一点的 API 接口,而 NVIDIA 通过 CUDA C++ WMMA(Warp Matrix Multiply Accumulate)API 向外提供了Tensor Core 在 Warp 级别上的计算操作支持,我们下面就来学习下如何使用 WMMA 这个 API

下面我们跟随 NVIDIA 官方文档一起来看看关于 WMMA API 的一些说明

Warp Matrix Functions(Warp 级矩阵函数如 WMMA)利用 Tensor Core 为形如 D=A*B+C 的矩阵乘加操作提供硬件加速支持。这些操作支持混合精度浮点数据,该功能仅在 compute capability≥7.0 的设备上可用

函数与类型定义如下(所有接口都位于命名空间 nvcuda::wmma):

// 定义一个矩阵片段(fragment),由 warp 中所有线程共同持有
template<typename Use, int m, int n, int k, typename T, typename Layout=void> class fragment;// 主要 API
void load_matrix_sync  (fragment<...> &a, const T* mptr, unsigned ldm);
void load_matrix_sync  (fragment<...> &a, const T* mptr, unsigned ldm, layout_t layout);
void store_matrix_sync (T* mptr, const fragment<...> &a, unsigned ldm, layout_t layout);
void fill_fragment     (fragment<...> &a, const T& v);
void mma_sync          (fragment<...> &d, const fragment<...> &a,const fragment<...> &b, const fragment<...> &c,bool satf=false);

其中:

  • fragment:Tensor Core 数据存储类,支持 matrix_amatrix_baccumulator
  • load_matrix_sync:Tensor Core 数据加载 API,支持将矩阵数据从 global memory 或 shared memory 加载到 fragment
  • store_matrix_sync:Tensor Core 结果存储 API,支持将计算结果从 fragment 存储到 global memory 或 shared memory
  • fill_fragment:fragment 填充 API,支持常数值填充
  • mma_sync:Tensor Core 矩阵乘计算 API,支持 D = AB + C 或者 C = AB +C

下面示例演示了在单 warp 内完成 16×16×16 的矩阵乘加:

#include <mma.h>
using namespace nvcuda;__global__ void wmma_ker(half *a, half *b, float *c) {// Declare the fragmentswmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag;wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;// Initialize the output to zerowmma::fill_fragment(c_frag, 0.0f);// Load the inputswmma::load_matrix_sync(a_frag, a, 16);wmma::load_matrix_sync(b_frag, b, 16);// Perform the matrix multiplicationwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);// Store the outputwmma::store_matrix_sync(c, c_frag, 16, wmma::mem_row_major);
}

该核函数依次完成 fragment 填充、加载、乘加及存储四步,由 warp 内所有线程同步执行

更多细节大家可以查看 CUDA 官方文档:https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#wmma

下面我们就正式开始学习利用 WMMA API 来实现 hgemm 半精度矩阵乘法的计算

2. hgemm_v1_wmma_m16n16k16_naive_kernel

我们先来看 V1 版本的原生实现,也就是单纯调用 wmma 接口不做任何优化来实现 hgemm:

在这里插入图片描述

我们要实现的是半精度矩阵乘法 Cm×n=Am×k×Bk×nC_{m\times n}=A_{m\times k} \times B_{k \times n}Cm×n=Am×k×Bk×n,在我们的示例中 m=512,n=2048,k=1024m=512,n=2048,k=1024m=512,n=2048,k=1024

在 V1 版本中每个 block 启动 32 个线程即一个 warp 来完成 C 中 16x16 的输出 tile,也就是每个 block 完成的是 16x1024 * 1024x16 的矩阵乘法计算

前面我们提到 wmma 是 warp 级别的接口且实现的是 16x16x16 的矩阵乘积,因此如果要实现 16x1024 * 1024x16 就需要沿着 K 维度(步长等于 WMMA_K = 16)不断去调用 wmma 指令乘积,每次指令完成的是 16x16 * 16x16 的计算,如上图所示

实现代码如下:

#include <iostream>
#include <cuda_runtime.h>
#include "common/tester.h"
#include "common/common.h"using namespace nvcuda;// only 1 warp per block(32 threads), m16n16k16. A, B, C: all row_major
template<const int WMMA_M = 16, const int WMMA_N = 16, const int WMMA_K = 16>
__global__ void hgemm_wmma_m16n16k16_naive_kernel(half* A, half* B, half* C, int M, int N, int K){const int NUM_K_TILES = div_ceil(K, WMMA_K);const int load_gmem_a_m = blockIdx.y * WMMA_M;const int load_gmem_b_n = blockIdx.x * WMMA_N;if(load_gmem_a_m >= M && load_gmem_b_n >= N){return;}wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> C_frag;wmma::fill_fragment(C_frag, 0.0);#pragma unrollfor(int k = 0; k < NUM_K_TILES; ++k){wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag;wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag;wmma::load_matrix_sync(A_frag, A + load_gmem_a_m * K + k * WMMA_K, K);wmma::load_matrix_sync(B_frag, B + (k * WMMA_K) * N + load_gmem_b_n, N);wmma::mma_sync(C_frag, A_frag, B_frag, C_frag);__syncthreads();}wmma::store_matrix_sync(C + load_gmem_a_m * N + load_gmem_b_n, C_frag, N, wmma::mem_row_major);
}void hgemm_wmma_m16n16k16_naive(half* A, half* B, half* C, int M, int N, int K){constexpr int WMMA_M = 16;constexpr int WMMA_N = 16;constexpr int WMMA_K = 16;dim3 block(32);dim3 grid(div_ceil(N, WMMA_N), div_ceil(M, WMMA_M));hgemm_wmma_m16n16k16_naive_kernel<WMMA_M, WMMA_N, WMMA_K><<<grid, block>>>(A, B, C, M, N, K);
}int main(int argc, char* argv[]){Tester tester(512, 2048, 1024, 1, 10, 100, true);tester.evaluate(hgemm_wmma_m16n16k16_naive, "hgemm_wmma_m16n16k16_naive");return 0;
}

下面我们从整体到细节,分几个部分来剖析 hgemm_wmma_m16n16k16_naive_kernel 这个 kernel 是如何通过 WMMA API 在 Tensor Core 上完成 16x16x16 的 hgemm 运算的:(from ChatGPT)

1. 网格和线程组织

  • Block 大小
dim3 block(32);

每个 block 只启用 32 个线程,也就是正好一个 warp。这意味着一个 block(即一个 warp)负责计算一个 16x16 的输出 tile

  • Grid 大小
dim3 grid(div_ceil(N, WMMA_N), div_ceil(M, WMMA_M));
  • grid.x = ceil(N/16):沿输出矩阵列方向(N)划分 tile
  • grid.y = ceil(M/16):沿输出矩阵行方向(M)划分 tile

因此,网格中的每个 block(bx, by)负责输出矩阵 C 上坐标为 (row=by×16,col=bx×16)(\mathrm{row}=by \times 16, \, \mathrm{col}=bx \times 16)(row=by×16,col=bx×16) 处的 16x16 子块

2. 核函数声明

template<const int WMMA_M = 16, const int WMMA_N = 16, const int WMMA_K = 16>
__global__ void hgemm_wmma_m16n16k16_naive_kernel(half* A, half* B, half* C, int M, int N, int K) {
  • Tensor Core 配置:使用 16x16x16 的矩阵块
  • 内存布局:所有矩阵均为行主序(row-major)

3. 计算全局内存中的加载位置

const int NUM_K_TILES = div_ceil(K, WMMA_K);  // K 维度分块数
const int load_gmem_a_m = blockIdx.y * WMMA_M;  // A 矩阵行起始位置
const int load_gmem_b_n = blockIdx.x * WMMA_N;  // B 矩阵列起始位置if(load_gmem_a_m >= M && load_gmem_b_n >= N){return;
}
  • 分块策略
    • blockIdx.y 控制输出矩阵 行方向分块(M 维度)
    • blockIdx.x 控制输出矩阵 列方向分块(N 维度)
    • 每个 block 计算一个 16x16 的结果块
  • 边界检查
    • 跳过无效 block(当矩阵尺寸非 16 倍数时)

4. Fragment 的声明与初始化

wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> C_frag;
wmma::fill_fragment(C_frag, 0.0f);
  • Accumulator fragment
    • 类型参数 wmma::accumulator, M, N, K, half 表示这是一个用来存放累加结果的 fragment,元素类型为 half
    • fill_fragmentC_frag 的每一个元素都初始化为 0

5. 分块遍历 K 维

const int NUM_K_TILES = div_ceil(K, WMMA_K);
for(int k = 0; k < NUM_K_TILES; ++k) {}
  • 总共将 K 维分成 NUM_K_TILES = ceil(K/16) 段,每次循环处理 16 列/行的乘加

6. 加载 A、B 子矩阵到 fragment

wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag;wmma::load_matrix_sync(A_frag,A + load_gmem_a_m * K + k * WMMA_K,K
);
wmma::load_matrix_sync(B_frag,B + (k * WMMA_K) * N + load_gmem_b_n,N
);
  • 指针偏移
    • load_gmem_a_m = by * 16load_gmem_b_n = bx * 16
    • A 矩阵子块首地址:
      Ablock=A+(by×16)×K+k×16A_{\mathrm{block}} = A + (by \times 16) \times K + k \times 16Ablock=A+(by×16)×K+k×16
    • 这里 K 是 A 的行主序 leading dimension(每行长度)、
    • B 矩阵子块首地址:
      Bblock=B+(k×16)×N+bx×16B_{\mathrm{block}} = B + (k \times 16) \times N + bx \times 16Bblock=B+(k×16)×N+bx×16
    • 这里 N 是 B 的行主序 leading dimension
  • row major
    • 两个 load_matrix_sync 都指定了 wmma::row_major,告诉 WMMA 按行主序去加载数据
  • Warp 内部协同
    • load_matrix_sync 在底层是一条 warp 级指令,32 个线程协同完成将 16x16 half 数据装载到 fragment 中

7. Tensor Core 乘加

wmma::mma_sync(C_frag, A_frag, B_frag, C_frag);
__syncthreads();
  • mma_sync
    • 执行一次 Cout=A×B+CinC_{\mathrm{out}} = A \times B + C_{\mathrm{in}}Cout=A×B+Cin 的 16x16x16 half-precision 运算,所有 32 个线程共同触发一条 Tensor Core 指令
  • 同步屏障
    • __syncthreads() 确保本 warp(也是本 block,block 只有一个 warp)中所有线程完成 mma 后再进入下一轮循环

8. 将累加结果写回全局内存

wmma::store_matrix_sync(C + load_gmem_a_m * N + load_gmem_b_n,C_frag,N,wmma::mem_row_major
);
  • C 子块首地址:
    Cblock=C+(by×16)×N+(bx×16)C_{\mathrm{block}} = C + (by \times 16) \times N + (bx \times 16)Cblock=C+(by×16)×N+(bx×16)
  • store_matrix_sync
    • C_frag 按行主序写回到全局内存

hgemm_v1 版本的实现还是比较简单,每个 block 仅包含一个 warp 负责 C 矩阵中一个 16x16 的 tile,通过 load_matrix_syncmma_syncstore_matrix_sync 这样三步标准流程,多次迭代累加完成完整的矩阵乘加

nsight compute 的性能和带宽测试结果如下:

优化手段矩阵维度GridBlock耗时(us)Memory [%]DRAM Throughout(%)Compute(SM)[%]
baseline(ampere_h1688gemm_128x128_ldg8_stages_32x1_nn)m=512,n=2048,k=1024--83.4266.7141.6169.67
hgemm_v1_wmma_m16n16k16_naive_kernelm=512,n=2048,k=1024(128,32)(32)618.1483.7220.1728.29

Note:测试设备 NVIDIA RTX3060,CUDA-11.6,launch 次数 2000

Memory Chart 内存图如下所示:

在这里插入图片描述

关于 nsight compute 的简单使用可以参考:【CUDA调优指南】合并访存

3. hgemm_v2_wmma_m16n16k16_mma4x2_kernel

V1 版本中每个 block 仅一个 warp,且每次循环都要从全局内存中加载子块,导致带宽浪费和低并发。我们可以考虑让每个 block 开启更多的线程,计算更大的输出 tile 来提高 occupancy,此外还可以利用共享内存降低全局带宽消耗:

在这里插入图片描述

在 V2 版本中,每个 block 开启的线程数是 256,也就是 8 个 warp,一个 block 负责 64x32 的输出子矩阵,8 个 warp 按 (warp_m, warp_n) 网格(4x2)分工,每个 warp 依旧负责计算一个 16x16 的子 tile

每一次迭代都会从 global memory 中加载矩阵元素到 shared memory 中,其中矩阵 A 需要加载 64x16 的子 tile,矩阵 B 需要加载 16x32 的子 tile,这些子 tile 再分配给各个 warp 进行 wmma 指令运算得到最终 64x32 的输出子 tile 结果

实现代码如下:

#include <iostream>
#include <cuda_runtime.h>
#include "common/tester.h"
#include "common/common.h"#define WARP_SIZE 32
#define LDST32BITS(value) (reinterpret_cast<half2 *>(&(value))[0])
#define LDST64BITS(value) (reinterpret_cast<float2 *>(&(value))[0])using namespace nvcuda;// m16n16k16 wmma  + tile MMA with smem,  A, B, C: all row_major
template<const int WMMA_M = 16, const int WMMA_N = 16, const int WMMA_K = 16,const int WMMA_TILE_M = 4, const int WMMA_TILE_N = 2>
__global__ void hgemm_wmma_m16n16k16_mma4x2_kernel(half* A, half* B, half* C, int M, int N, int K){// 256 thread(8 warps) per block.const int bx = blockIdx.x;const int by = blockIdx.y;const int NUM_K_TILES = div_ceil(K, WMMA_K);constexpr int BM = WMMA_M * WMMA_TILE_M;       // 16x4=64constexpr int BN = WMMA_N * WMMA_TILE_N;       // 16x2=32constexpr int BK = WMMA_K;                     // 16__shared__ half s_a[BM][BK], s_b[WMMA_K][BN];  // 64x16x2=2KB, 16x32x2=1KB// 要保证相同的 warp 下 thread 执行相同指令// warp_id 0 -> warp_m 0, warp_n 0// warp_id 1 -> warp_m 0, warp_n 1// warp_id 2 -> warp_m 1, warp_n 0// warp_id 3 -> warp_m 1, warp_n 1const int tid = threadIdx.y * blockDim.x + threadIdx.x;const int warp_id = tid / WARP_SIZE;  // 0~7 warp_id within blockconst int lane_id = tid % WARP_SIZE;  // 0~31const int warp_m = warp_id / 2;       // 0,1,2,3const int warp_n = warp_id % 2;       // 0,1// 256 线程分别 load s_a=64x16, s_b=16x32// 64x16/256=4, half4, 16x32/256=2, half2// s_a, 64x16, 每个线程 load 4 half, 每行需要 4 线程, 64 行, 共 256 线程const int load_smem_a_m = tid / 4;        // 0~63const int load_smem_a_k = (tid % 4) * 4;  // 0,4,8,12// s_b, 16x32, 每个线程 load 2 half, 每行需要 8 线程, 32 行, 共 256 线程const int load_smem_b_k = tid / 16;                 // 0~16const int load_smem_b_n = (tid % 16) * 2;           // 0,2,4,...,30const int load_gmem_a_m = by * BM + load_smem_a_m;  // global mconst int load_gmem_b_n = bx * BN + load_smem_b_n;  // global nif(load_gmem_a_m >= M && load_gmem_b_n >= N){return;}wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> C_frag;wmma::fill_fragment(C_frag, 0.0);#pragma unrollfor(int k = 0; k < NUM_K_TILES; ++k){int load_gmem_a_k = k * WMMA_K + load_smem_a_k;  // global col of aint load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;int load_gmem_b_k = k * WMMA_K + load_smem_b_k;  // global row of bint load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;// // 64 bits sync memory issues gmem_a -> smem_a// LDST64BITS(s_a[load_smem_a_m][load_smem_a_k]) = LDST64BITS(A[load_gmem_a_addr]);// // 32 bits sync memory issues gmem_b -> smem_b// LDST32BITS(s_b[load_smem_b_k][load_smem_b_n]) = LDST32BITS(B[load_gmem_b_addr]);// 64 bits sync memory issues gmem_a -> smem_a.LDST64BITS(s_a[load_smem_a_m][load_smem_a_k]) =(LDST64BITS(A[load_gmem_a_addr]));// 32 bits sync memory issues gmem_b -> smem_b.LDST32BITS(s_b[load_smem_b_k][load_smem_b_n]) =(LDST32BITS(B[load_gmem_b_addr]));        __syncthreads();wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag;wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag;wmma::load_matrix_sync(A_frag, &s_a[warp_m * WMMA_M][0], BK);  // BM*BK, BK=WMMA_Kwmma::load_matrix_sync(B_frag, &s_b[0][warp_n * WMMA_N], BN);  // BK=BN, BK=WMMA_Kwmma::mma_sync(C_frag, A_frag, B_frag, C_frag);__syncthreads();}const int store_gmem_a_m = by * BM + warp_m * WMMA_M;const int store_gmem_a_n = bx * BN + warp_n * WMMA_N;wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag, N, wmma::mem_row_major);
}void hgemm_wmma_m16n16k16_mma4x2(half* A, half* B, half* C, int M, int N, int K){constexpr int WMMA_M = 16;constexpr int WMMA_N = 16;constexpr int WMMA_K = 16;constexpr int WMMA_TILE_M = 4;constexpr int WMMA_TILE_N = 2;dim3 block(256);dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N), div_ceil(M, WMMA_M * WMMA_TILE_M));hgemm_wmma_m16n16k16_mma4x2_kernel<WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N><<<grid, block>>>(A, B, C, M, N, K);
}int main(int argc, char* argv[]){Tester tester(512, 2048, 1024, 1, 10, 100, true);tester.evaluate(hgemm_wmma_m16n16k16_mma4x2, "hgemm_wmma_m16n16k16_mma4x2");return 0;
}

下面我们详细拆解下 V2 版本的 kernel:(from ChatGPT)

1. 网格和线程组织

  • Block 大小
dim3 block(256);

每个 block 只启用 256 个线程,也就是 8 个 warp。每个 warp 各自负责计算一个 16x16 的输出 tile

  • Grid 大小
dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N), div_ceil(M, WMMA_M * WMMA_TILE_M));
  • grid.x = ceil(N/32):沿输出矩阵列方向(N)划分 tile
  • grid.y = ceil(M/64):沿输出矩阵行方向(M)划分 tile

每个 block 计算一个 64x32 的大 tile,分给 4x2=8 个 warp 细分计算

2. 宏定义

#define WARP_SIZE 32
#define LDST32BITS(value) (reinterpret_cast<half2 *>(&(value))[0])
#define LDST64BITS(value) (reinterpret_cast<float2 *>(&(value))[0])
  • WARP_SIZE:一个 warp 的线程数(32)
  • LDST32BITS/64BITS:利用 half2(32 bit)和 float2(64 bit)做向量化读写,以提升全局与共享内存的带宽利用

3. 核函数声明

template<const int WMMA_M = 16,const int WMMA_N = 16,const int WMMA_K = 16,const int WMMA_TILE_M = 4,const int WMMA_TILE_N = 2>
__global__ void hgemm_wmma_m16n16k16_mma4x2_kernel(half* A, half* B, half* C, int M, int N, int K) {
  • 模板参数
    • WMMA_M/N/K=16:Tensor Core 支持的 16x16x16 tile
    • WMMA_TILE_N=4, WMMA_TILE_N=2:在 M/N 方向各自再做细粒度 tiling,每个 block 覆盖 M 方向 16x4=64、N 方向 16x2=32

4. Tile 尺寸和共享内存计算

// 1) 计算 tile 尺寸、分段数,申请 Shared Memory
const int bx = blockIdx.x
const int by = blockIdx.y;
const int NUM_K_TILES = div_ceil(K, WMMA_K);
constexpr int BM = WMMA_M * WMMA_TILE_M;   // 64
constexpr int BN = WMMA_N * WMMA_TILE_N;   // 32
constexpr int BK = WMMA_K;                 // 16
__shared__ half s_a[BM][BK];               // 64×16
__shared__ half s_b[BK][BN];               // 16×32
  • bx, by:确定本 block 覆盖的 C 大 tile 在 (by×BM, bx×BN) 处
  • NUM_K_TILES:K 维需切分的段数
  • s_a, s_b:分别缓存 A 的 64x16 和 B 的 16x32,大小约 3KB

5. 线程组织优化

// 2) 线程布局:256 线程 = 8 warps
int tid     = threadIdx.y * blockDim.x + threadIdx.x;
int warp_id = tid / WARP_SIZE;            // [0..7]
int lane_id = tid % WARP_SIZE;            // [0..31]
int warp_m  = warp_id / WMMA_TILE_N;      // ∈[0..3]
int warp_n  = warp_id % WMMA_TILE_N;      // ∈[0..1]
  • tid:block 内唯一线程索引 (0…255)
  • warp_id:当前线程处于第几个 warp (0…7)
  • warp_m/warp_n:决定这个 warp 负责的 16x16 子 tile 在大 tile 中的位置

6. 共享内存加载索引计算

// 256 线程分别 load s_a=64x16, s_b=16x32
// 64x16/256=4, half4, 16x32/256=2, half2
// s_a, 64x16, 每个线程 load 4 half, 每行需要 4 线程, 64 行, 共 256 线程
const int load_smem_a_m = tid / 4;        // 0~63
const int load_smem_a_k = (tid % 4) * 4;  // 0,4,8,12
// s_b, 16x32, 每个线程 load 2 half, 每行需要 8 线程, 32 行, 共 256 线程
const int load_smem_b_k = tid / 16;                 // 0~15
const int load_smem_b_n = (tid % 16) * 2;           // 0,2,4,...,28,30
const int load_gmem_a_m = by * BM + load_smem_a_m;  // global m
const int load_gmem_b_n = bx * BN + load_smem_b_n;  // global n
  • load_smem_…:把 64x16 和 16x32 区块,按每线程 2–4 个 half 的粒度分摊到 256 线程上

7. Fragment 的声明与初始化

// 3) 为本 warp 准备累加 fragment
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> C_frag;
wmma::fill_fragment(C_frag, 0.0f);
  • 初始化一个 16x16 的累加寄存器 fragment,清零

那大家可能有所困惑,为什么我们有 8 个 warp,但累加 fragment 即 C_frag 只声明了一次呢?🤔

在 V2 里,C_frag 虽然只在代码中写了一份声明,但它其实是 “每个线程都有一份” 的局部变量,编译后会分配到每个线程的寄存器里。WMMA 的 load/mma_sync/store 系列接口是 warp-synchronous 的 intrinsic,调用时会让同一个 warp(32 个线程)内的寄存器协同构成这个 warp 的完整 fragment。不同 warp 的线程各自维护自己那份寄存器组,互不干扰,所以在代码中只需要写一次声明就能满足 8 个 warp 同时运行的需求

8. 核心计算逻辑

// 4) K 维分段迭代:先 load 到 SMEM,再每 warp 调用 WMMA
for(int k = 0; k < NUM_K_TILES; ++k){// 4.1) 每线程算出自己要 load 的全局内存坐标int load_gmem_a_k = k * WMMA_K + load_smem_a_k;  // global col of aint load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;int load_gmem_b_k = k * WMMA_K + load_smem_b_k;  // global row of bint load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;// 4.2 向量化加载到 SMEM// A: 用 64-bit 载入 half4LDST64BITS(s_a[load_smem_a_m][load_smem_a_k]) = LDST64BITS(A[load_gmem_a_addr]);// B: 用 32-bit 载入 half2LDST32BITS(s_b[load_smem_b_k][load_smem_b_n]) = LDST32BITS(B[load_gmem_b_addr]);__syncthreads();
  • LDSTxxBITS:一次装入两个或四个 half,提高访存效率
  • __syncthreads() 确保全部数据加载完成
    // 4.3) 本 warp 从 SMEM 装载自己的 16×16 子 tilewmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag;wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag;// A 子 tile 在 s_a 的第 warp_m*16 行起始,leading dim=BK=16wmma::load_matrix_sync(A_frag, &s_a[warp_m * WMMA_M][0], BK);// B 子 tile 在 s_b 的第 warp_n*16 列起始,leading dim=BN=32wmma::load_matrix_sync(B_frag, &s_b[0][warp_n * WMMA_N], BN);// 4.4) Tensor Core 乘加wmma::mma_sync(C_frag, A_frag, B_frag, C_frag);__syncthreads();}
  • load_matrix_sync:从共享内存读入 fragment,只不过这里的数据取自 SMEM 而非 GMEM
  • mma_sync:触发一次 16x16x16 half GEMM

9. 将累加结果写回全局内存

// 5) 写回全局 C
const int store_gmem_a_m = by * BM + warp_m * WMMA_M;
const int store_gmem_a_n = bx * BN + warp_n * WMMA_N;
wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag, N, wmma::mem_row_major);
  • 每个 warp 将 16x16 结果写回对应全局内存位置

V2 保留了 V1 中调用 WMMA API 的核心三步(load_matrix_syncmma_syncstore_matrix_sync),但通过 “二级 tiling + 向量化加载” 在共享内存中复用数据,大幅提升了内存带宽利用率和硬件并行度

相比 V1,V2 主要在以下几个方面做了优化:

1. 更大粒度的 Block 级 Tile

  • V1 每个 Block 只计算一个 16x16 的子块
  • V2 将 Block 扩展到 64x32(由 4x2 个 16x16 子块组成),一个 Block 内部 8 个 warp 并行协作,提升了线程并发度和硬件利用率

2. 二级 Tiling 结构

  • 在 K 维沿用 WMMA 的 16 列分段
  • 在 M/N 方向又做了更粗的 tiling(BM=16x4,BN=16x2),将多个子块一起加载到共享内存,再由各 warp 复用

3. Shared Memory 复用

  • V1 中每个 warp 在每个 K 分段都会从全局内存重新加载 16x16 的 A、B 矩阵
  • V2 通过把整块 64x16(A)和 16x32(B)一次性 load 进 SMEM,让同一个 Block 内所有 warp、同一 K 段内的多次 tensor core 调用都可直接复用,大幅减少全局访存次数

4. 向量化聚合加载

  • V1 只有 load_matrix_sync(warp 级)一个读内存通道
  • V2 利用 256 个线程配合 half2(32 bit)/ float2(64 bit)做合并读写,先并行把大块数据搬到 SMEM,再用 WMMA API 装载,显著提升带宽利用率

nsight compute 的性能和带宽测试结果如下:

优化手段矩阵维度GridBlock耗时(us)Memory [%]DRAM Throughout(%)Compute(SM)[%]
baseline(ampere_h1688gemm_128x128_ldg8_stages_32x1_nn)m=512,n=2048,k=1024--83.4266.7141.6169.67
hgemm_v1_wmma_m16n16k16_naive_kernelm=512,n=2048,k=1024(128,32)(32)618.1483.7220.1728.29
hgemm_v2_wmma_m16n16k16_mma4x2_kernelm=512,n=2048,k=1024(64,8)(256)309.0688.3926.3828.05

Memory Chart 内存图如下所示:

在这里插入图片描述

4. hgemm_v3_wmma_m16n16k16_mma4x2_warp2x4_kernel

在 V2 版本中每个 warp 依旧只计算一个 16x16 的子块,我们可以让每个 warp 同时累加多个子 tile(如 2x4 块)来提高算/访比:

在这里插入图片描述

在 V3 版本中,每个 block 开启的线程数依旧是 256,也就是 8 个 warp,注意此时一个 block 负责 128x128 的输出子矩阵。8 个 warp 依旧按 (warp_m, warp_n) 网格(4x2)分工,但每个 warp 负责计算 8 个 16x16 的子 tile

代码如下:

#include <iostream>
#include <cuda_runtime.h>
#include "common/tester.h"
#include "common/common.h"#define WARP_SIZE 32
#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])using namespace nvcuda;// m16n16k16 wmma  + tile MMA with smem,  A, B, C: all row_major
template<const int WMMA_M = 16, const int WMMA_N = 16, const int WMMA_K = 16,const int WMMA_TILE_M = 4, const int WMMA_TILE_N = 2,const int WARP_TILE_M = 2, const int WARP_TILE_N = 4>
__global__ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_kernel(half* A, half* B, half* C, int M, int N, int K){// 256 thread(8 warps) per block.const int bx = blockIdx.x;const int by = blockIdx.y;const int NUM_K_TILES = div_ceil(K, WMMA_K);constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M;  // 16x4*2=128constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N;  // 16x2*4=128constexpr int BK = WMMA_K;                              // 16__shared__ half s_a[BM][BK], s_b[BK][BN];               // 16x128x2=4KB// 要保证相同的 warp 下 thread 执行相同指令// warp_id 0 -> warp_m 0, warp_n 0// warp_id 1 -> warp_m 0, warp_n 1// warp_id 2 -> warp_m 1, warp_n 0// warp_id 3 -> warp_m 1, warp_n 1const int tid = threadIdx.y * blockDim.x + threadIdx.x;const int warp_id = tid / WARP_SIZE;  // 0~7 warp_id within blockconst int lane_id = tid % WARP_SIZE;  // 0~31const int warp_m = warp_id / 2;       // 0,1,2,3const int warp_n = warp_id % 2;       // 0,1// 0. 先计算 shared memory 中的索引// tid 和需要加载的 smem s_a[BM][BK] 之间的索引关系 BM=128 BK=8 按行读取 A 行主序// 对于 s_a 每行 16 个数据,每个线程读取 8 个,需要 2 个线程;总共 128 行,需要 128x2 刚好 256 线程const int load_smem_a_m = tid / 2;                 // row 0~127const int load_smem_a_k = (tid % 2 == 0) ? 0 : 8;  // col 0,8// tid 和需要加载的 smem s_b[BK][BN] 之间的索引关系 BK=16 BN=128 按行读取 B 行主序// 对于 s_b 每行 128 个数据,每个线程读 8 个数据,需要 16 个线程,总共 16 行,需要 16x16=256 个线程const int load_smem_b_k = tid / 16;                 // 0~15const int load_smem_b_n = (tid % 16) * 8;           // 0,8,16,...,120// 1. 再计算全局内存中的索引// 要加载到 s_a 中的元素对应到 A 全局内存中的行数// 每个 block 负责出 C 中大小为 BM*BN 的块const int load_gmem_a_m = by * BM + load_smem_a_m;  // global row of a and cconst int load_gmem_b_n = bx * BN + load_smem_b_n;  // global col of b and cif(load_gmem_a_m >= M || load_gmem_b_n >= N){return;}wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> C_frag[WARP_TILE_M][WARP_TILE_N];
#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){wmma::fill_fragment(C_frag[i][j], 0.0);}}#pragma unrollfor(int k = 0; k < NUM_K_TILES; ++k){int load_gmem_a_k = k * WMMA_K + load_smem_a_k;  // global col of aint load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;int load_gmem_b_k = k * WMMA_K + load_smem_b_k;  // global row of bint load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) =(LDST128BITS(B[load_gmem_b_addr]));LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) =(LDST128BITS(A[load_gmem_a_addr]));    __syncthreads();wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag[WARP_TILE_M];wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag[WARP_TILE_N];#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){// load 2 tiles -> reg, smem a -> frags a, warp_m 0~3const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;wmma::load_matrix_sync(A_frag[i], &s_a[warp_smem_a_m][0], BK);  // BM*BK, BK=WMMA_K}#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){// load 4 tiles -> reg, smem b -> frags b, warp_n 0~2const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;wmma::load_matrix_sync(B_frag[j], &s_b[0][warp_smem_b_n], BN);  // BK=BN, BK=WMMA_K}#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);}}__syncthreads();}#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){const int store_gmem_a_m = by * BM + warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;const int store_gmem_a_n = bx * BN + warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag[i][j], N, wmma::mem_row_major);}}
}void hgemm_wmma_m16n16k16_mma4x2_warp2x4(half* A, half* B, half* C, int M, int N, int K){constexpr int WMMA_M = 16;constexpr int WMMA_N = 16;constexpr int WMMA_K = 16;constexpr int WMMA_TILE_M = 4;constexpr int WMMA_TILE_N = 2;constexpr int WARP_TILE_M = 2;constexpr int WARP_TILE_N = 4;dim3 block(256);dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N * WARP_TILE_N), div_ceil(M, WMMA_M * WMMA_TILE_M * WARP_TILE_M));hgemm_wmma_m16n16k16_mma4x2_warp2x4_kernel<WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, WARP_TILE_M, WARP_TILE_N><<<grid, block>>>(A, B, C, M, N, K);
}int main(int argc, char* argv[]){Tester tester(512, 2048, 1024, 1, 10, 100, true);tester.evaluate(hgemm_wmma_m16n16k16_mma4x2_warp2x4, "hgemm_wmma_m16n16k16_mma4x2_warp2x4");return 0;
}

下面我们对 V3 Kernel 进行细致剖析:(from ChatGPT)

1. 网格和线程组织

  • Block 大小
dim3 block(256);

每个 block 只启用 256 个线程,也就是 8 个 warp

  • Grid 大小
dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N), div_ceil(M, WMMA_M * WMMA_TILE_M));
  • grid.x = ceil(N/128):沿输出矩阵列方向(N)划分 tile
  • grid.y = ceil(M/128):沿输出矩阵行方向(M)划分 tile

每个 block 计算一个 128x128 的大 tile,分给 4x2=8 个 warp 细分计算,每个 warp 各自负责计算一个 32x64 的输出 tile,其中 M 方向是 2 个 16x16 块共 32 行,N 方向是 4 个 16x16 块共 64 列

2. 宏定义

#define WARP_SIZE 32
#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])
  • WARP_SIZE:一个 warp 的线程数(32)
  • LDST32BITS/64BITS:用 float4(128 bit)做一次性向量化读/写

3. 核函数声明

template<const int WMMA_M = 16, const int WMMA_N = 16, const int WMMA_K = 16,const int WMMA_TILE_M = 4, const int WMMA_TILE_N = 2,const int WARP_TILE_M = 2, const int WARP_TILE_N = 4>
__global__ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_kernel(half* A, half* B, half* C, int M, int N, int K){
  • WMMA_TILE_M/N = 4/2:在 M/N 方向各自处理 4/2 个 16x16 子 tile(同 V2)
  • WARP_TILE_M/N = 2/4:再在单个 warp 内,沿 M/N 方向分别处理两行、四列 16x16 子 tile
    • 所以一个 warp 负责的输出区域是 (16×2)×(16×4)=32×64(16 \times 2) \times (16 \times 4) = 32 \times 64(16×2)×(16×4)=32×64

4. Tile 尺寸和共享内存计算

const int bx = blockIdx.x;
const int by = blockIdx.y;
const int NUM_K_TILES = div_ceil(K, WMMA_K);
constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M;  // 16×4×2 = 128
constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N;  // 16×2×4 = 128
constexpr int BK = WMMA_K;                              // 16
__shared__ half s_a[BM][BK], s_b[BK][BN];               // 128×16 + 16×128 ≈ 8 KB
  • Block 负责输出 128x128 的大 tile
  • Shared Memory 比 V2 翻了一倍,用于缓存更大范围的 A/B

5. 线程索引与子 tile 映射

int tid     = threadIdx.y * blockDim.x + threadIdx.x; // 0…255
int warp_id = tid / WARP_SIZE;                        // 0…7 (8 warps)
int lane_id = tid % WARP_SIZE;                        // 0…31
// 仍旧一个 block 有 8 warp,但每个 warp 现在做 2×4 个 WMMA 计算
int warp_m  = warp_id / 2;  // 决定 warp Block 内在 M 方向的哪一大行(0…3)
int warp_n  = warp_id % 2;  // 决定 warp 在 N 方向的哪一大列(0…1)
  • warp_m/n 计算同 V2

6. 共享内存加载索引计算

// s_a加载索引 (128×16)
const int load_smem_a_m = tid / 2;                 // 0~127
const int load_smem_a_k = (tid % 2 == 0) ? 0 : 8;  // 0或8// s_b加载索引 (16×128)
const int load_smem_b_k = tid / 16;       // 0~15
const int load_smem_b_n = (tid % 16) * 8; // 0,8,16,...,120const int load_gmem_a_m = by * BM + load_smem_a_m;  // global row of a and c
const int load_gmem_b_n = bx * BN + load_smem_b_n;  // global col of b and c
  • 访问模式
    • s_a:每行 16 元素 ➡ 2 线程/行(每个线程加载 8 个元素)
    • s_b:每行 128 元素 ➡ 16 线程/行(每个线程加载 8 个元素)

7. Fragment 的声明与初始化

wmma::fragment<wmma::accumulator, ...> C_frag[WARP_TILE_M][WARP_TILE_N];#pragma unroll
for(int i = 0; i < WARP_TILE_M; ++i){for(int j = 0; j < WARP_TILE_N; ++j){wmma::fill_fragment(C_frag[i][j], 0.0);}
}
  • 每个 warp 分配 2x4 个累加 fragment,分别对应它负责的 8 块 16x16

8. K 维分段 & Shared Memory 加载

#pragma unrollfor(int k = 0; k < NUM_K_TILES; ++k){int load_gmem_a_k = k * WMMA_K + load_smem_a_k;  // global col of aint load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;int load_gmem_b_k = k * WMMA_K + load_smem_b_k;  // global row of bint load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) =(LDST128BITS(B[load_gmem_b_addr]));LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) =(LDST128BITS(A[load_gmem_a_addr]));    __syncthreads();
  • load_smem_…:用 256 个线程协同把 128x16 和 16x128 块加载到 SMEM
  • LDST128BITS:一次 128 位(16 half)吞吐,带宽利用率进一步提升
  • __syncthreads() 确保所有数据就绪

9. 从 SMEM 到 WMMA Fragment

        wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag[WARP_TILE_M];wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag[WARP_TILE_N];#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){// load 2 tiles -> reg, smem a -> frags a, warp_m 0~3const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;wmma::load_matrix_sync(A_frag[i], &s_a[warp_smem_a_m][0], BK);  // BM*BK, BK=WMMA_K}#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){// load 4 tiles -> reg, smem b -> frags b, warp_n 0~2const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;wmma::load_matrix_sync(B_frag[j], &s_b[0][warp_smem_b_n], BN);  // BK=BN, BK=WMMA_K}
  • 每个 warp 从 SMEM 中,按自己分块的「行/列偏移」依次装载 2 块 A_frag、4 块 B_frag

10. 发起多次 Tensor Core 乘加

        #pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);}}__syncthreads();
  • 一个 warp 在同一 K 分段内发起 2x4=8 次 16x16x16 半精度乘加

11. 结果写回全局内存

    #pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){const int store_gmem_a_m = by * BM + warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;const int store_gmem_a_n = bx * BN + warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag[i][j], N, wmma::mem_row_major);}}
  • 逐块写回 本 warp 负责的 2x4 个 16x16 输出子块

V3 相比 V2 的主要优化点有:

1. Warp 内部多重 Tile

  • V2:每 warp 只做 1 块 16x16
  • V3:每 warp 承担 2x4 块 16x16 ➡ 输出区域从 16x16 扩到 32x64,算力利用率更高

2. 更大 Block Tile

  • V2 Block 覆盖 64x32
  • V3 Block 扩大到 128x128,减少了 kernel launch 数量与边界处理开销

3. 128-bit 向量加载

  • V2 用 half2/float2(32/64 bit)
  • V3 用 float4(128 bit)一次搬更多数据,进一步提高全局 ➡ SMEM 带宽

4. 更深度的 Register Blocking

  • V3 在每个 warp 内维护 8 个累加 fragment,利用寄存器存储更多中间结果,减少写回次数

5. 更优的算/访比

  • 单次 K 分段内,Warp 发起 8 次 mma_sync,而 SMEM 只 load 一次,提高了计算与访存的重叠效率

6. 更高的 SMEM 复用

  • V3 缓存更大片 A/B 到 SMEM,所有 warp x 多次 tile 全部复用

V3 在 V2 的基础上,通过「Warp 级二次 tiling」和「128 bit 向量加载」,使每个 warp 做更多的矩阵分块运算,从更大块的 SMEM 重用数据,从而最大化 Tensor Core 的吞吐和内存带宽利用,进一步提升 GEMM 性能

nsight compute 的性能和带宽测试结果如下:

优化手段矩阵维度GridBlock耗时(us)Memory [%]DRAM Throughout(%)Compute(SM)[%]
baseline(ampere_h1688gemm_128x128_ldg8_stages_32x1_nn)m=512,n=2048,k=1024--83.4266.7141.6169.67
hgemm_v1_wmma_m16n16k16_naive_kernelm=512,n=2048,k=1024(128,32)(32)618.1483.7220.1728.29
hgemm_v2_wmma_m16n16k16_mma4x2_kernelm=512,n=2048,k=1024(64,8)(256)309.0688.3926.3828.05
hgemm_v3_wmma_m16n16k16_mma4x2_warp2x4_kernelm=512,n=2048,k=1024(16,4)(256)257.6361.9910.3722.44

Memory Chart 内存图和 V2 一致

5. hgemm_v4_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_kernel

在 V3 基础上,我们可以进一步为 A/B 各分配双缓冲共享内存,结合 CUDA 异步拷贝指令异步预取下一个 tile 分块数据,从而实现访存与 WMMA 计算重叠,降低同步开销并提升带宽利用率:

在这里插入图片描述

图中 V4 版本的双缓冲异步拷贝与 WMMA 计算管线可以分成以下几个阶段:(from ChatGPT)

1. 共享内存申请

  • 申请两组 shared memory 即 buffer 0 和 buffer 1

2. 首轮预加载(k=0)

  • 将 A、B 的第 0 个 tile 分块异步拷贝到共享内存的 buffer 0,并等待拷贝完成

3. 主循环:计算—预取交错(k=1 … NUM_K_TILES-1)

对于每个后续 tile 索引 k:

  • 预取下一 tile:异步拷贝 A[k]、B[k] 到共享内存 buffer 中,但 不立即等待
  • 利用当前缓冲执行 WMMA:从上一次拷贝完成的 buffer 把多组子 tile 转载到 fragment 中,然后执行 WMMA 指令计算
  • 提交并等待下一轮拷贝:计算期间,异步拷贝正进行,在 MMA 完成后再等待拷贝完成,确保下一个 tile 已拷入共享内存

4. 尾轮计算(最后一个 k)

  • 循环结束后,最后一块 tile 已在 buffer 1 中,跳过异步拷贝环节,直接从共享内存加载 fragment,执行最后一次 MMA

5. 结果写回

  • 所有子 tile 的累加结果此时都在寄存器 C_frag 数组里,逐一写回全局内存即可

这样就实现了 “拷贝下一 tile ↔ 计算当前 tile” 并行执行,最大化算/访比和硬件利用率

代码如下:

#include <iostream>
#include <cuda_runtime.h>
#include "common/tester.h"
#include "common/common.h"#define WARP_SIZE 32using namespace nvcuda;// Double buffers
template<const int WMMA_M = 16, const int WMMA_N = 16, const int WMMA_K = 16,const int WMMA_TILE_M = 4, const int WMMA_TILE_N = 2,const int WARP_TILE_M = 2, const int WARP_TILE_N = 4,const int OFFSET = 0>
__global__ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_kernel(half* A, half* B, half* C, int M, int N, int K){// 256 thread(8 warps) per block.const int bx = blockIdx.x;const int by = blockIdx.y;const int NUM_K_TILES = div_ceil(K, WMMA_K);constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M;  // 16x4*2=128constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N;  // 16x2*4=128constexpr int BK = WMMA_K;                              // 16// 16x128x2=4KB, 4+4=8KB, padding to reduce bank conflicts__shared__ half s_a[2][BM][BK + OFFSET], s_b[2][BK][BN + OFFSET];// 要保证相同的 warp 下 thread 执行相同指令const int tid = threadIdx.y * blockDim.x + threadIdx.x;const int warp_id = tid / WARP_SIZE;  // 0~7 warp_id within blockconst int lane_id = tid % WARP_SIZE;  // 0~31const int warp_m = warp_id / 2;       // 0,1,2,3const int warp_n = warp_id % 2;       // 0,1// 0. 先计算 shared memory 中的索引// tid 和需要加载的 smem s_a[BM][BK] 之间的索引关系 BM=128 BK=8 按行读取 A 行主序// 对于 s_a 每行 16 个数据,每个线程读取 8 个,需要 2 个线程;总共 128 行,需要 128x2 刚好 256 线程const int load_smem_a_m = tid / 2;                 // row 0~127const int load_smem_a_k = (tid % 2 == 0) ? 0 : 8;  // col 0,8// tid 和需要加载的 smem s_b[BK][BN] 之间的索引关系 BK=16 BN=128 按行读取 B 行主序// 对于 s_b 每行 128 个数据,每个线程读 8 个数据,需要 16 个线程,总共 16 行,需要 16x16=256 个线程const int load_smem_b_k = tid / 16;                 // 0~15const int load_smem_b_n = (tid % 16) * 8;           // 0,8,16,...,120// 1. 再计算全局内存中的索引// 要加载到 s_a 中的元素对应到 A 全局内存中的行数// 每个 block 负责出 C 中大小为 BM*BN 的块const int load_gmem_a_m = by * BM + load_smem_a_m;  // global row of a and cconst int load_gmem_b_n = bx * BN + load_smem_b_n;  // global col of b and cif(load_gmem_a_m >= M || load_gmem_b_n >= N){return;}wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> C_frag[WARP_TILE_M][WARP_TILE_N];
#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){wmma::fill_fragment(C_frag[i][j], 0.0);}}// k = 0 is loading here, buffer 0{int load_gmem_a_k = load_smem_a_k;  // global col of aint load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;int load_gmem_b_k = load_smem_b_k;  // global row of bint load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;uint32_t load_smem_a_ptr = __cvta_generic_to_shared(&s_a[0][load_smem_a_m][load_smem_a_k]);CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);uint32_t load_smem_b_ptr = __cvta_generic_to_shared(&s_b[0][load_smem_b_k][load_smem_b_n]);CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);CP_ASYNC_COMMIT_GROUP();CP_ASYNC_WAIT_GROUP(0);}__syncthreads();#pragma unrollfor(int k = 1; k < NUM_K_TILES; ++k){  // start from 1int smem_sel = (k - 1) & 1;        // k 1->0, k 2->1, k 3->0, ...int smem_sel_next = k & 1;         // k 1->1, k 2->0, k 3->1, ...int load_gmem_a_k = k * WMMA_K + load_smem_a_k;  // global col of aint load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;int load_gmem_b_k = k * WMMA_K + load_smem_b_k;  // global row of bint load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;uint32_t load_smem_a_ptr = __cvta_generic_to_shared(&s_a[smem_sel_next][load_smem_a_m][load_smem_a_k]);CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);uint32_t load_smem_b_ptr = __cvta_generic_to_shared(&s_b[smem_sel_next][load_smem_b_k][load_smem_b_n]);CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag[WARP_TILE_M];wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag[WARP_TILE_N];#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){// load 2 tiles -> reg, smem a -> frags a, warp_m 0~3const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;wmma::load_matrix_sync(A_frag[i], &s_a[smem_sel][warp_smem_a_m][0], BK + OFFSET);  // BM*BK, BK=WMMA_K}#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){// load 4 tiles -> reg, smem b -> frags b, warp_n 0~2const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;wmma::load_matrix_sync(B_frag[j], &s_b[smem_sel][0][warp_smem_b_n], BN + OFFSET);  // BK=BN, BK=WMMA_K}#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);}}CP_ASYNC_COMMIT_GROUP();CP_ASYNC_WAIT_GROUP(0);__syncthreads();}// processing last k tile{wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag[WARP_TILE_M];wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag[WARP_TILE_N];#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){// load 2 tiles -> reg, smem a -> frags a, warp_m 0~3const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;wmma::load_matrix_sync(A_frag[i], &s_a[1][warp_smem_a_m][0], BK + OFFSET);}#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){// load 4 tiles -> reg, smem b -> frags b, warp_n 0~2const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;wmma::load_matrix_sync(B_frag[j], &s_b[1][0][warp_smem_b_n], BN + OFFSET);}#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);}}        }// finally, store back to C matrix#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){const int store_gmem_a_m = by * BM + warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;const int store_gmem_a_n = bx * BN + warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag[i][j], N, wmma::mem_row_major);}}    
}void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async(half* A, half* B, half* C, int M, int N, int K){constexpr int WMMA_M = 16;constexpr int WMMA_N = 16;constexpr int WMMA_K = 16;constexpr int WMMA_TILE_M = 4;constexpr int WMMA_TILE_N = 2;constexpr int WARP_TILE_M = 2;constexpr int WARP_TILE_N = 4;constexpr int OFFSET = 0;dim3 block(256);dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N * WARP_TILE_N), div_ceil(M, WMMA_M * WMMA_TILE_M * WARP_TILE_M));hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_kernel<WMMA_M, WMMA_N, WMMA_K, WMMA_TILE_M, WMMA_TILE_N, WARP_TILE_M, WARP_TILE_N, OFFSET><<<grid, block>>>(A, B, C, M, N, K);
}int main(int argc, char* argv[]){Tester tester(512, 2048, 1024, 1, 10, 100, true);tester.evaluate(hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async, "hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async");return 0;
}

下面我们对 V4 Kernel 进行详细剖析:(from ChatGPT)

1. 网格和线程组织

  • Block 大小
dim3 block(256);

每个 block 只启用 256 个线程,也就是 8 个 warp

  • Grid 大小
dim3 grid(div_ceil(N, WMMA_N * WMMA_TILE_N), div_ceil(M, WMMA_M * WMMA_TILE_M));
  • grid.x = ceil(N/128):沿输出矩阵列方向(N)划分 tile
  • grid.y = ceil(M/128):沿输出矩阵行方向(M)划分 tile

每个 block 计算一个 128x128 的大 tile,分给 4x2=8 个 warp 细分计算,每个 warp 各自负责计算一个 32x64 的输出 tile,其中 M 方向是 2 个 16x16 块共 32 行,N 方向是 4 个 16x16 块共 64 列,和 V3 保持一致

2. 基本设置与模板参数

#define WARP_SIZE 32
  • 每个 warp 固定 32 线程
template<const int WMMA_M = 16, const int WMMA_N = 16, const int WMMA_K = 16,const int WMMA_TILE_M = 4, const int WMMA_TILE_N = 2,const int WARP_TILE_M = 2, const int WARP_TILE_N = 4,const int OFFSET = 0>
__global__ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_kernel(half* A, half* B, half* C, int M, int N, int K){
  • WMMA_M/N/K = 16:Tensor Core 支持的 tiling 维度
  • WMMA_TILE_M/N = 4/2:Block 级 M/N 方向 tiling
  • WARP_TILE_M/N = 2/4:Warp 级的寄存器 blocking
  • OFFSET:在 shared memory 行尾做 padding 以减少 bank 冲突

3. Block / Warp 布局与 Shared Memory 双缓冲

const int bx = blockIdx.x;
const int by = blockIdx.y;
const int NUM_K_TILES = div_ceil(K, WMMA_K);
constexpr int BM = WMMA_M * WMMA_TILE_M * WARP_TILE_M;  // 128
constexpr int BN = WMMA_N * WMMA_TILE_N * WARP_TILE_N;  // 128
constexpr int BK = WMMA_K;                          // 16// 双缓冲:两个 ping‑pong 口,一边加载一边计算
__shared__ half s_a[2][BM][BK + OFFSET];
__shared__ half s_b[2][BK][BN + OFFSET];
  • Block 负责 128x128 子矩阵,使用 4x2x2x4 结构同 V3
  • 双缓冲s_a[0]/s_b[0]s_1[1]/s_b[1] 交替使用,用于 overlap 全局 ➡ SMEM 传输与 WMMA 计算

4. 线程索引与 Warp 定位

int tid     = threadIdx.y * blockDim.x + threadIdx.x;  // 0…255
int warp_id = tid / WARP_SIZE;                        // 0…7
int lane_id = tid % WARP_SIZE;                        // 0…31
int warp_m  = warp_id / 2;  // M 方向第几大行 tile
int warp_n  = warp_id % 2;  // N 方向第几大列 tile
  • 与 V3 相同,Block 内 8 个 warp;每个 warp 负责寄存器内的 2x4 tiling 块

5. 共享内存加载索引计算

// s_a加载索引 (128×16)
const int load_smem_a_m = tid / 2;                 // 0~127
const int load_smem_a_k = (tid % 2 == 0) ? 0 : 8;  // 0或8// s_b加载索引 (16×128)
const int load_smem_b_k = tid / 16;       // 0~15
const int load_smem_b_n = (tid % 16) * 8; // 0,8,16,...,120const int load_gmem_a_m = by * BM + load_smem_a_m;  // global row of a and c
const int load_gmem_b_n = bx * BN + load_smem_b_n;  // global col of b and c
  • 访问模式
    • s_a:每行 16 元素 ➡ 2 线程/行(每个线程加载 8 个元素)
    • s_b:每行 128 元素 ➡ 16 线程/行(每个线程加载 8 个元素)

6. Fragment 的声明与初始化

wmma::fragment<wmma::accumulator, ...> C_frag[WARP_TILE_M][WARP_TILE_N];#pragma unroll
for(int i = 0; i < WARP_TILE_M; ++i){for(int j = 0; j < WARP_TILE_N; ++j){wmma::fill_fragment(C_frag[i][j], 0.0);}
}
  • 每个 warp 分配 2x4 个累加 fragment,分别对应它负责的 8 块 16x16

7. 首段 K‑tile 异步预取(ping buffer)

// k = 0 is loading here, buffer 0
{int load_gmem_a_k = load_smem_a_k;  // global col of aint load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;int load_gmem_b_k = load_smem_b_k;  // global row of bint load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;uint32_t load_smem_a_ptr = __cvta_generic_to_shared(&s_a[0][load_smem_a_m][load_smem_a_k]);CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);uint32_t load_smem_b_ptr = __cvta_generic_to_shared(&s_b[0][load_smem_b_k][load_smem_b_n]);CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);CP_ASYNC_COMMIT_GROUP();// 等待该 group 完成CP_ASYNC_WAIT_GROUP(0);
}
__syncthreads();
  • CP_ASYNC_CG:使用 CUDA 的异步拷贝指令将 16x16 数据行从全局加载到 Shared Memory
  • CP_ASYNC_COMMIT_GROUP + WAIT_GROUP(0):提交并等待该批次完成,保证第一段数据可用

这里我们用到了一些非常底层的 PTX 式异步拷贝指令,下面简单解释下:(from ChatGPT)

__cvta_generic_to_shared(ptr) 函数的功能是在设备端实际去执行 PTX 的 cvta.to.shared 指令,把一个 “通用”(generic)指针转换成 “shared” 地址空间下的原始位模式,没有它我们就无法在内联 PTX(如 cp.async)中正确指定 shared memory 目标

关于这个函数更详细的说明大家可以参考 StackOverflow 上的这个回答:https://stackoverflow.com/questions/76992939/confusion-about-cvta-generic-to-shared

NotePTX(Parallel Thread Execution) 是 NVIDIA 的虚拟指令集(ISA),CUDA C++ 编译器会先把你的内核(Kernel)翻译成 PTX,然后 Driver JIT 再把 PTX 转成最终 GPU 的机器码。我们可以把 PTX 想象成 NVIDIA GPU 的 “中间汇编语言” 或 “虚拟机器码”—就像 Java 的字节码(bytecode)或编译器里的中间表示(IR)一样

CP_ASYNC_CG / CP_ASYNC_COMMIT_GROUP / CP_ASYNC_WAIT_GROUP 这些宏都直接包装了 PTX 的异步拷贝指令:

CP_ASYNC_CG(dst, src, Bytes)

#define CP_ASYNC_CG(dst, src, Bytes) \asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(Bytes))
#else

意味着从全局内存 src 异步拷贝 Bytes 字节到 shared memory dst 中,使用 L2 缓存路径(cg

CP_ASYNC_COMMIT_GROUP()

#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)

把上面所有 cp.async 指令归为一组,标记它们可以一起等待

CP_ASYNC_WAIT_GROUP(n)

#define CP_ASYNC_WAIT_GROUP(N) asm volatile("cp.async.wait_group %0;\n" ::"n"(N))

等待组号 ≥n 的所有拷贝完成。通常用 wait_group(0) 就是等所有未完成的异步拷贝都做完,保证数据可见

8. 主循环:双缓冲切换+计算+预取下一段

#pragma unroll
for(int k = 1; k < NUM_K_TILES; ++k){  // start from 1// 正在计算的 buffer 索引(ping)int smem_sel = (k - 1) & 1;        // k 1->0, k 2->1, k 3->0, ...// 下一段预取到的 buffer(pong)int smem_sel_next = k & 1;         // k 1->1, k 2->0, k 3->1, ...int load_gmem_a_k = k * WMMA_K + load_smem_a_k;  // global col of aint load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;int load_gmem_b_k = k * WMMA_K + load_smem_b_k;  // global row of bint load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;// 8.1 异步预取到 ponguint32_t load_smem_a_ptr = __cvta_generic_to_shared(&s_a[smem_sel_next][load_smem_a_m][load_smem_a_k]);CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16);uint32_t load_smem_b_ptr = __cvta_generic_to_shared(&s_b[smem_sel_next][load_smem_b_k][load_smem_b_n]);CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);// 8.2 从 ping 中加载到 fragment,并发起 8 次 mmawmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag[WARP_TILE_M];wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag[WARP_TILE_N];#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){// load 2 tiles -> reg, smem a -> frags a, warp_m 0~3const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;wmma::load_matrix_sync(A_frag[i], &s_a[smem_sel][warp_smem_a_m][0], BK + OFFSET);  // BM*BK, BK=WMMA_K}#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){// load 4 tiles -> reg, smem b -> frags b, warp_n 0~2const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;wmma::load_matrix_sync(B_frag[j], &s_b[smem_sel][0][warp_smem_b_n], BN + OFFSET);  // BK=BN, BK=WMMA_K}#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);}}// 8.3 提交并等待该预取 group 完成CP_ASYNC_COMMIT_GROUP();CP_ASYNC_WAIT_GROUP(0);__syncthreads();
}
  • 双缓冲切换curr buffer 用于当前计算,next buffer 并行预取下一段
  • 异步+并发:在计算 ping 数据时,同时发起对 pong 的 cp.async,最大化隐藏访存延迟

9. 处理最后一段 K-tile

// 处理最后剩余一个 ping buffer 的计算
{// load_matrix_sync + mma_sync 与上面相同wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> A_frag[WARP_TILE_M];wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> B_frag[WARP_TILE_N];#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){// load 2 tiles -> reg, smem a -> frags a, warp_m 0~3const int warp_smem_a_m = warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;wmma::load_matrix_sync(A_frag[i], &s_a[1][warp_smem_a_m][0], BK + OFFSET);}#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){// load 4 tiles -> reg, smem b -> frags b, warp_n 0~2const int warp_smem_b_n = warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;wmma::load_matrix_sync(B_frag[j], &s_b[1][0][warp_smem_b_n], BN + OFFSET);}#pragma unrollfor(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){wmma::mma_sync(C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]);}}        
}
  • 因为主循环预取了 NUM_K_TILES-1 段,最后还需对最后一段的 ping buffer 做一次加载+mma

10. 写回全局 C

// finally, store back to C matrix
#pragma unroll
for(int i = 0; i < WARP_TILE_M; ++i){#pragma unrollfor(int j = 0; j < WARP_TILE_N; ++j){const int store_gmem_a_m = by * BM + warp_m * (WMMA_M * WARP_TILE_M) + i * WMMA_M;const int store_gmem_a_n = bx * BN + warp_n * (WMMA_N * WARP_TILE_N) + j * WMMA_N;wmma::store_matrix_sync(C + store_gmem_a_m * N + store_gmem_a_n, C_frag[i][j], N, wmma::mem_row_major);}
} 
  • 与 V3 相同,逐块写回寄存器中累加的 8 个 16x16 子 tile

V4 相比 V3 的改进与优化点有:

1. 双缓冲 Shared Memory

  • V3 每段计算前后都要同步载入全局 ➡ SMEM
  • V4 用两组 ping-pong buffer,将加载与计算交错进行,减少 __syncthreads() 的阻塞时长

2. 异步拷贝指令(cp.async)

  • V3 靠向量化加载 LDST128BITS__syncthreads()
  • V4 用 CUDA 提供的 CP_ASYNC_CG 系列指令,让数据拷贝在硬件管道中完成,并且可以在不阻塞计算的情况下打包提交和等待

3. 访存/计算完全重叠

  • V4 在计算当前段的同时发起对下一段的预取,最大化利用内存带宽和 Tensor Core,进一步提高算法/访存比

4. 更少的显式同步

  • 只在每段预取提交后用 group-level wait(CP_ASYNC_WAIT_GROUP)和一个 __syncthreads(),避免了全 block 过早等待

5. Padding 减少 bank 冲突

  • OFFSET 参数可在 shared memory 末尾做行对齐,减少多个 warp 同时访问不同 buffer 时的 bank 冲突

V4 通过 “双缓冲+异步拷贝+交错计算” 模式,将全局内存到共享内存的数据传输与 Tensor Core 计算深度融合,最大程度隐藏访存延迟,显著提升了流水线效率和硬件吞吐

nsight compute 的性能和带宽测试结果如下:

优化手段矩阵维度GridBlock耗时(us)Memory [%]DRAM Throughout(%)Compute(SM)[%]
baseline(ampere_h1688gemm_128x128_ldg8_stages_32x1_nn)m=512,n=2048,k=1024--83.4266.7141.6169.67
hgemm_v1_wmma_m16n16k16_naive_kernelm=512,n=2048,k=1024(128,32)(32)618.1483.7220.1728.29
hgemm_v2_wmma_m16n16k16_mma4x2_kernelm=512,n=2048,k=1024(64,8)(256)309.0688.3926.3828.05
hgemm_v3_wmma_m16n16k16_mma4x2_warp2x4_kernelm=512,n=2048,k=1024(16,4)(256)257.6361.9910.3722.44
hgemm_v4_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_kernelm=512,n=2048,k=1024(16,4)(256)199.7468.1414.1128.61

Memory Chart 内存图和 V3 一致

OK,以上就是 hgemm 各种优化的代码实现了

结语

这篇文章我们学习了利用更底层的 API 即 CUDA C++ WMMA API 来实现 hgemm 矩阵乘法运算

我们跟随 UP 分析了 LeetCUDA 仓库中 hgemm 4 个版本的实现,首先原生即 V1 版本仅使用 WMMA API 实现 hgemm,每个 block 仅包含一个 warp,调用一次 mma 指令

V2 版本考虑让每个 block 开启更多的线程,计算更大的输出 tile 来提高 occupancy,同时还利用了共享内存来降低全局带宽消耗

V3 版本在 V2 版本的基础上让每个 warp 同时累加多个子 tile(如 2x4 块)来提高算/访比

V4 版本则通过双缓冲机制,结合 CUDA 异步拷贝指令预取下一个 tile 分块数据,实现访存与 WMMA 计算的重叠,从而降低同步开销并提升带宽利用率

如果大家之前学习过 sgemm 的一些优化,会发现这里的 hgemm 优化思路和 sgemm 其实非常像,都是沿着 提升算/访比、最大化并行度、最小化内存带宽压力和同步开销 这几条主线不断演进,只不过需要注意的是 mma 指令是 warp 级别的,时时刻刻要考虑的是一个 warp

总的来说,这几个版本的 hgemm 实现相对还是比较好理解的,大家感兴趣的可以多看看 up 主的视频,还是非常不错的🤗

下载链接

  • Hgemm 矩阵乘法代码下载链接【提取码:1234】

参考

  • 【CUDA进阶】Tensor Core实战教程(已完结)
  • https://github.com/xlite-dev/LeetCUDA
  • https://github.com/Bruce-Lee-LY/cuda_hgemm
  • https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#wmma
  • https://chatgpt.com
http://www.dtcms.com/a/289525.html

相关文章:

  • 【JS逆向基础】数据库之redis
  • TypeScript环境安装和操作
  • 将 VHD/VHDX 转换为物理磁盘
  • 无 sudo 权限的环境下将 nvcc (CUDA Toolkit) 安装到个人目录 linux
  • 虚拟地址空间
  • rman清理归档
  • 2024年全国青少年信息素养大赛Scratch编程挑战赛 小低组初赛
  • 【JDK内置工具】常用工具和实战指令
  • 贝叶斯分类器的相关理论学习
  • 力扣面试150(34/150)
  • 人脸识别:AI 如何精准 “认人”?
  • Florence2-通用表征完成多种视觉任务的视觉基础模型
  • 最新轻量美化表白墙系统源码v2.0 带后台版 附搭建教程
  • 分治算法---归并
  • 智能制造——48页毕马威:汽车营销与研发数字化研究【附全文阅读】
  • Muduo库中单例模式详解
  • 【Anaconda】Conda 虚拟环境打包迁移教程
  • 基于ACPs协议的智能体互联网示例(多智能体旅游规划)
  • JMeter连接数据库
  • Linux操作系统从入门到实战(十一)回车换行问题与用户缓冲区问题
  • C++虚函数易错点整理
  • 20250720-4-Kubernetes 调度-指定节点调度:nodeSelectornodeAffinity笔记
  • LeetCode 3202.找出有效子序列的最大长度 II:取模性质(动态规划)
  • JDK8默认垃圾回收器
  • (Python)类和类的方法进阶(基础教程介绍)(Python基础教程)
  • 利用核壳生物支架调控纤维 - 成骨稳态【AbMole】
  • Linux:线程控制
  • 【网络编程】网络传输-JSON
  • 【C语言】字符串与字符函数详解(下)
  • Shell脚本-cut工具