【CUDA】Sgemm单精度矩阵乘法(下)
目录
- 前言
- 1. 优化技巧5:使用register模拟二级缓存(内积转外积)
- 2. 优化技巧6:使用register模拟二级缓存 + float4
- 3. 优化技巧7:global memory转置再存放shared memory
- 4. 优化技巧8:使用double buffer加速矩阵乘法
- 结语
- 下载链接
- 参考
前言
学习 UP 主 比飞鸟贵重的多_HKL 的 【CUDA】Sgemm单精度矩阵乘法(已完结~) 视频,记录下个人学习笔记,仅供自己参考😄
refer 1:【CUDA】Sgemm单精度矩阵乘法(已完结~)
refer 2:https://github.com/tpoisonooo/how-to-optimize-gemm/cuda
refer 3:https://chatgpt.com/
1. 优化技巧5:使用register模拟二级缓存(内积转外积)
我们接着上篇文章来讲解 sgemm 的优化
前面在 v2 版本中我们通过分块的方式将数据从 global memory 放置在 shared memory 中,大大减少了访存所需要的时延,这里,我们进一步考虑从 shared memory 到 register 的过程,如下图所示:
Note:图片来自于:深入浅出GPU优化系列:GEMM优化(一)
通过寄存器来模拟二级缓存,可以将内积形式转换为外积形式,如下图所示:
上图左边展示的是经典内积的形式,通过三层循环,每次计算 C_tile[m][n]
时,从 A_tile[m][k]
(蓝色行)与 B[k][n]
(蓝色列)中加载值,计算后逐步累加
上图右边改写为按 k
为最外层的循环,变为外积实现。先固定一个 k
,将 A_tile[:,k]
(蓝色列向量)和 B_tile[k,:]
加载到寄存器中,对 C_tile
的一个子块进行外积更新,相当于更新一个 rank-1 子矩阵,这样可以就减少对全体 A_tile/B_tile
数据的重复加载,起到模拟二级缓存(寄存器暂存)的作用
下图展示了利用 register 内积转外积的整体流程:
在将数据从 global memory 加载到 shared memory 之后,还需要进一步加载到 register 中,接着通过外积计算方式逐步累加得到最终的结果
下图还对比了 v4 和 v5 版本的差异,主要体现在寄存器的使用以及内积转外积的实现:
代码如下:
template<unsigned int BLOCK_SIZE, unsigned int NUM_PER_THREAD>
__global__ void cuda_sgemm_v5_register_outer_product(float* A, float* B, float* C, const int M, const int N, const int K) {int row = blockIdx.y * blockDim.y + threadIdx.y;int col = (blockIdx.x * blockDim.x + threadIdx.x) * NUM_PER_THREAD;extern __shared__ float shared_mem[];float* A_tile = shared_mem;float* B_tile = shared_mem + BLOCK_SIZE * BLOCK_SIZE;constexpr int REG_NUM = NUM_PER_THREAD / 2;float A_reg[REG_NUM] = {0.0f};float B_reg[REG_NUM] = {0.0f};float sum[REG_NUM * REG_NUM] = {0.0f};// re-arrange the layoutint tid = threadIdx.y * blockDim.x + threadIdx.x;int ctx = tid % (BLOCK_SIZE / REG_NUM);int cty = tid / (BLOCK_SIZE / REG_NUM);for(int k_base = 0; k_base < K; k_base += BLOCK_SIZE){// load A_tile from global memory to shared memoryint a_col = k_base + threadIdx.x * NUM_PER_THREAD;FLOAT4(A_tile[threadIdx.y * BLOCK_SIZE + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(A[row * K + a_col]);// load B_tile from global memory to shared memoryint b_row = k_base + threadIdx.y;FLOAT4(B_tile[threadIdx.y * BLOCK_SIZE + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(B[b_row * N + col]);__syncthreads();// use register to compute the sum of A_tile * B_tile for(int k = 0; k < BLOCK_SIZE; ++k){A_reg[0] = A_tile[(cty * REG_NUM) * BLOCK_SIZE + k];A_reg[1] = A_tile[(cty * REG_NUM + 1) * BLOCK_SIZE + k];B_reg[0] = B_tile[k * BLOCK_SIZE + ctx * REG_NUM];B_reg[1] = B_tile[k * BLOCK_SIZE + ctx * REG_NUM + 1];for(int i = 0; i < REG_NUM; ++i){for(int j = 0; j < REG_NUM; ++j){sum[i * REG_NUM + j] += A_reg[i] * B_reg[j];}}}__syncthreads(); }// write the result to Cfloat* C_start = C + blockIdx.y * blockDim.y * N + blockIdx.x * blockDim.x * NUM_PER_THREAD;for(int i = 0; i < REG_NUM; ++i){for(int j = 0; j < REG_NUM; ++j){C_start[(cty * REG_NUM + i) * N + ctx * REG_NUM + j] = sum[i * REG_NUM + j];}}
}
下面是该代码的详细分析:(from ChatGPT)
1. 核函数签名和参数
template<unsigned int BLOCK_SIZE, unsigned int NUM_PER_THREAD>
__global__ void cuda_sgemm_v5_register_outer_product(float* A, float* B, float* C, const int M, const int N, const int K)
- 模板参数:
BLOCK_SIZE
:分块(TILE)的大小,也就是每次只加载BLOCK_SIZE * BLOCK_SIZE
大小的元素到共享内存中,在示例中BLOCK_SIZE = 16
NUM_PER_THREAD
:每个线程处理的元素数量,在示例中NUM_PER_THREAD = 4
也就是每个线程负责 C 矩阵中 2x2 大小的元素,
2. 线程索引计算
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = (blockIdx.x * blockDim.x + threadIdx.x) * NUM_PER_THREAD;
row
:当前线程处理的 A 矩阵的行索引col
:当前线程处理的 B 矩阵的起始列索引,注意由于每个线程处理 4 个连续元素,因此需要乘以NUM_PER_THREAD
3. 共享内存分配
extern __shared__ float shared_mem[];
float* A_tile = shared_mem;
float* B_tile = shared_mem + BLOCK_SIZE * BLOCK_SIZE;
- 共享内存布局:
- 动态共享内存
shared_mem
被划分为两部分:A_tile
:前BLOCK_SIZE * BLOCK_SIZE = 256
个 float,用于缓存 A 矩阵的分块B_tile
:后 256 个 float,用于缓存 B 矩阵的分块
- 总共享内存需求:512 个 float(2KB)
- 动态共享内存
4. 寄存器变量声明
constexpr int REG_NUM = NUM_PER_THREAD / 2;
float A_reg[REG_NUM] = {0.0f};
float B_reg[REG_NUM] = {0.0f};
float sum[REG_NUM * REG_NUM] = {0.0f};
REG_NUM
:NUM_PER_THREAD = 4
,所以REG_NUM = 2
- 这个设计意味着每个线程将处理 2x2=4 个输出元素
- 寄存器变量:
A_reg[2]
:缓存从 A_tile 加载的数据B_reg[2]
:缓存从 B_tile 加载的数据sum[4]
:累加 2x2 输出块的部分和
5. 数据布局重排
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int ctx = tid % (BLOCK_SIZE / REG_NUM);
int cty = tid / (BLOCK_SIZE / REG_NUM);
tid
计算:- 线性化的线程 ID,当前线程在 block 中的全局索引,范围 0~63(4x16 线程块)
ctx
和cty
计算:BLOCK_SIZE / REG_NUM = 16 / 2 = 8
ctx = tid % 8
:线程在 x 方向的逻辑索引(0~7)cty = tid / 8
:线程在 y 方向的逻辑索引(0~7)- 这种重排将 64 个线程组织为 8x8 的网格,每个线程负责 2x2 的输出块,如下图所示
6. 主计算循环
for(int k_base = 0; k_base < K; k_base += BLOCK_SIZE)
- 循环结构:
- 在 K 维度上分块处理,步长为
BLOCK_SIZE = 16
- 对于 K = 512,共需要 512 / 16 = 32 次迭代
- 在 K 维度上分块处理,步长为
6.1 A_tile
加载
int a_col = k_base + threadIdx.x * NUM_PER_THREAD;
FLOAT4(A_tile[threadIdx.y * BLOCK_SIZE + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(A[row * K + a_col]);
- 加载逻辑:
a_col
:当前处理的 A 矩阵列索引- 使用
FLOAT4
宽指令一次加载 A 矩阵中 4 个连续的 float - 写入共享内存
A_tile
中,按行主序排序
Note:关于索引的计算博主在 v4 版本中已经讲过了,这边就不再赘述了
6.2 B_tile
加载
int b_row = k_base + threadIdx.y;
FLOAT4(B_tile[threadIdx.y * BLOCK_SIZE + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(B[b_row * N + col]);
- 加载逻辑:
b_row
:当前处理的 B 矩阵行索引- 同样使用
FLOAT4
宽指令加载 B 矩阵的 4 个连续元素 - 写入共享内存
B_tile
中,按行主序排序
6.3 同步
__syncthreads();
- 确保所有线程完成共享内存的加载后才开始计算
6.4 寄存器缓存计算
for(int k = 0; k < BLOCK_SIZE; ++k){A_reg[0] = A_tile[(cty * REG_NUM) * BLOCK_SIZE + k];A_reg[1] = A_tile[(cty * REG_NUM + 1) * BLOCK_SIZE + k];B_reg[0] = B_tile[k * BLOCK_SIZE + ctx * REG_NUM];B_reg[1] = B_tile[k * BLOCK_SIZE + ctx * REG_NUM + 1];
- 寄存器加载:
- 每次迭代处理 K 维度的一个元素(k=0 到 15)
- 从
A_tile
加载两行到A_reg
(由cty
决定) - 从
B_tile
加载两列到B_reg
(由ctx
决定) - 这种访问模式确保了合并的内存访问
- 外积计算:
for(int i = 0; i < REG_NUM; ++i){for(int j = 0; j < REG_NUM; ++j){sum[i * REG_NUM + j] += A_reg[i] * B_reg[j];}}
}
- 计算模式:
- 这是典型的外积计算:A 的列向量(2 元素)与 B 的行向量(2 元素)相乘,得到 2x2 的矩阵
- 结果累加到
sum
数组中 - 共进行
BLOCK_SIZE = 16
次外积累加
博主绘制了一个草图来简要说明整个流程,如下图所示(以 (ctx,cty)=(0,0)
线程为例):
每个线程处理 A_tile
两行数据与 B_tile
两列数据相乘,每次加载 A_tile
两个数据和 B_tile
两个数据到寄存器,循环 BLOCK_SIZE
次
6.5 第二次同步
__syncthreads();
- 确保所有线程完成当前块的计算后再加载下一块
7. 结果写回
float* C_start = C + blockIdx.y * blockDim.y * N + blockIdx.x * blockDim.x * NUM_PER_THREAD;
for(int i = 0; i < REG_NUM; ++i){for(int j = 0; j < REG_NUM; ++j){C_start[(cty * REG_NUM + i) * N + ctx * REG_NUM + j] = sum[i * REG_NUM + j];}
}
C_start
计算:- 计算当前线程块对应的的 C 矩阵起始位置
blockIdx.y * blockDim.y * N
:当前线程块在 y 方向的偏移blockIdx.x * blockDim.x * NUM_PER_THREAD
:当前线程块在 x 方向的偏移
- 写回逻辑:
- 每个线程将其累加的 2x2 结果块写回全局内存
- 使用
cty
和ctx
确定写入位置,保持与计算时相同的布局
这里博主在自己实现时发现始终不能很好的定位到 C 的全局索引,绕着绕着就把自己给绕晕了
那其实 UP 主的方法非常的有效,我们先定位到当前 block 对应的矩阵 C 的位置,然后再来处理具体 block 中每个 thread 的部分,这样我们就只要关注每个 block 的索引计算就行了,会简单不少
它也可以写成下面的这种形式:
for(int i = 0; i < REG_NUM; ++i){for(int j = 0; j < REG_NUM; ++j){int c_row = blockIdx.y * blockDim.y + cty * REG_NUM + i;int c_col = blockIdx.x * blockDim.x * NUM_PER_THREAD + ctx * REG_NUM + j;C[c_row * N + c_col] = sum[i * REG_NUM + j];}
}
这个核函数使用寄存器来模拟二级缓存,其中:
- 寄存器缓存层次:
- 第一级:共享内存缓存全局内存数据
- 第二级:寄存器缓存共享内存数据
- 这种层次结构减少了共享内存的访问压力
- 数据流分析:
- 全局内存 ➡ 共享内存(通过
FLOAT4
宽加载) - 共享内存 ➡ 寄存器(标量加载)
- 寄存器 ➡ 计算单元(高效计算)
- 全局内存 ➡ 共享内存(通过
- 性能优化点:
- 每个线程的
A_reg
和B_reg
在循环中重复使用,减少共享内存访问 - 外积计算模式最大化数据复用率
- 寄存器访问完全无冲突,延迟极低
- 每个线程的
此外,这个核函数还将内积计算转换为外积计算,博主没有太理解内积转外积过程,这边简要说明下:(from ChatGPT)
首先我们需要理解内积和外积这两个基本概念,再来分析它们在矩阵乘法中的应用
1. 内积(点积)
数学定义:内积是两个向量的乘积,结果是一个标量(单个数值)
对于两个 n n n 维向量 a = [ a 1 , a 2 , … , a n ] \mathbf{a} = [a_1,a_2,\ldots,a_n] a=[a1,a2,…,an] 和 b = [ b 1 , b 2 , … , b n ] \mathbf{b} = [b_1,b_2,\ldots,b_n] b=[b1,b2,…,bn]
内积计算方式如下:
a ⋅ b = a 1 b 1 + a 2 b 2 + … + a n b n = ∑ i = 1 n ( a i b i ) \mathbf{a} \cdot \mathbf{b} = a_1b_1+a_2b_2+\ldots+a_nb_n=\sum_{i=1}^{n}(a_ib_i) a⋅b=a1b1+a2b2+…+anbn=i=1∑n(aibi)
传统矩阵乘法就是基于内积的:
for(int m = 0; m < M; m++){for(int n = 0; n < N; n++){float sum = 0;for(int k = 0; k < K; k++){sum += A[m * K + k] * B[k * N + n];}C[m * N + n] = sum;}
}
特点:
- 每个输出元素需要遍历 K 维度
- 内存访问模式不连续(A 按行访问,B 按列访问)
- 计算访存比低
2. 外积
数学定义:外积是两个向量的乘积,结果是一个矩阵
对于向量 u = [ u 1 , u 2 , … , u m ] \mathbf{u} = [u_1,u_2,\ldots,u_m] u=[u1,u2,…,um] 和 v = [ v 1 , v 2 , … , v n ] \mathbf{v} = [v_1,v_2,\ldots,v_n] v=[v1,v2,…,vn]
外积计算方式如下:
u × v T = [ u 1 u 2 ⋮ u m ] [ v 1 v 2 ⋯ v n ] = [ u 1 v 1 u 1 v 2 ⋯ u 1 v n u 2 v 1 u 2 v 2 ⋯ u 2 v n ⋮ ⋮ ⋱ ⋮ u m v 1 u m v 2 ⋯ u m v n ] \mathbf{u} \times \mathbf{v}^T = \begin{bmatrix} u_1 \\ u_2 \\ \vdots \\ u_m \end{bmatrix} \begin{bmatrix} v_1 & v_2 & \cdots & v_n \end{bmatrix}= \begin{bmatrix} u_1 v_1 & u_1 v_2 & \cdots & u_1 v_n \\ u_2 v_1 & u_2 v_2 & \cdots & u_2 v_n \\ \vdots & \vdots & \ddots & \vdots \\ u_m v_1 & u_m v_2 & \cdots & u_m v_n \end{bmatrix} u×vT= u1u2⋮um [v1v2⋯vn]= u1v1u2v1⋮umv1u1v2u2v2⋮umv2⋯⋯⋱⋯u1vnu2vn⋮umvn
外积方法将矩阵乘法计算重构为多个小矩阵的乘积累加,在核函数中的具体表现为:
1. 数据分块:
- 将
A_tile
按列分块(16x2 的小块) - 将
B_tile
按行分块(2x16 的小块)
2. 计算单元:
// 外积计算核心代码
A_reg[0] = A_tile[...]; // 加载A的一列2个元素
A_reg[1] = A_tile[...];B_reg[0] = B_tile[...]; // 加载B的一行2个元素
B_reg[1] = B_tile[...];// 计算2x2外积并累加
for(int i = 0; i < 2; i++)for(int j = 0; j < 2; j++)sum[i * 2 + j] += A_reg[i] * B_reg[j];
3. 数学表示:
每次迭代计算的是:
[ A 1 A 2 ] [ B 1 B 2 ] = [ A 1 B 1 A 1 B 2 A 2 B 1 A 2 A 2 ] \begin{bmatrix} A_1 \\ A_2 \\ \end{bmatrix} \begin{bmatrix} B_1 \ B_2 \end{bmatrix}= \begin{bmatrix} A_1 B_1 & A_1 B_2\\ A_2 B_1 & A_2 A_2\\ \end{bmatrix} [A1A2][B1 B2]=[A1B1A2B1A1B2A2A2]
然后将这些 2x2 的小矩阵累加到最终结果中
与传统内积方法相比,外积更适合 GPU,这主要是因为:
- 更高的数据复用:
- 在计算 2x2 外积时,A 的 2 个元素与 B 的 2 个元素产生 4 次乘加
- 相比内积方法(1 个 A 元素 x 1 个 B 元素 = 1 次乘加),计算密度提高 4 倍
- 更连续的内存访问:
- 外积方法中,A 按列访问,B 按行访问,都是连续内存访问
- 内积方法中,B 必须按列访问,导致非连续访问
- 更适合 SIMT 架构:
- GPU 擅长并行执行相同指令
- 外积方法让每个线程处理一个小矩阵,并行度高
- 内积方法线程间计算模式差异大,不利于并行
- 更高的计算强度:
- 每次加载的数据参与更多计算操作
- 更好地隐藏内存访问延迟
性能和带宽测试结果如下:
优化手段 | 矩阵维度 | Grid | Block | 耗时(us) | Memory Throughout(%) | DRAM Throughout(%) |
---|---|---|---|---|---|---|
v0_global_memory | 512x512 | (32,32) | (16,16) | 471.78 | 96.94 | 1.56 |
v1_shared_memory | 256x256 | (16,16) | (16,16) | 82.11 | 78.92 | 1.84 |
v2_shared_memory_sliding_window | 512x512 | (32,32) | (16,16) | 362.50 | 94.45 | 7.05 |
v3_increase_work_of_per_thread | 512x512 | (16,16) | (16,16) | 204.26 | 84.01 | 3.64 |
v4_using_float4 | 512x512 | (32,32) | (4,16) | 209.60 | 91.99 | 3.44 |
v5_register_outer_product | 512x512 | (32,32) | (4,16) | 206.18 | 79.10 | 3.50 |
2. 优化技巧6:使用register模拟二级缓存 + float4
这个小节我们尝试将 v4 和 v5 版本融合起来,既使用寄存器外积形式,又使用 float4 向量化加载,流程如下图所示:
由于我们需要使用 float4 来向量化加载,因此寄存器的数量我们各增加到 4 个,此时每个线程负责处理 4x4 大小的数据,如上图所示。另外由于 B_tile
中的四个元素在内存中是连续的,因此可以使用 float4 加载,而对应的 A_tile
中的四个元素并不是连续存储的,所以对于 A_tile
我们还是来一个个手动加载到寄存器中
代码如下:
template<unsigned int NUM_PER_TILE, unsigned int NUM_PER_THREAD>
__global__ void cuda_sgemm_v6_register_outer_product_float4(float* A, float* B, float* C, const int M, const int N, const int K){int row = (blockIdx.y * blockDim.y + threadIdx.y) * NUM_PER_THREAD;int col = (blockIdx.x * blockDim.x + threadIdx.x) * NUM_PER_THREAD;extern __shared__ float shared_mem[];float* A_tile = shared_mem;float* B_tile = shared_mem + NUM_PER_TILE * NUM_PER_TILE;float A_reg[NUM_PER_THREAD] = {0.0f};float B_reg[NUM_PER_THREAD] = {0.0f};float sum[NUM_PER_THREAD * NUM_PER_THREAD] = {0.0f};for(int k_base = 0; k_base < K; k_base += NUM_PER_TILE){for(int i = 0; i < NUM_PER_THREAD; ++i){// load A_tile from global memory to shared memoryint a_col = k_base + threadIdx.x * NUM_PER_THREAD;FLOAT4(A_tile[(threadIdx.y * NUM_PER_THREAD + i) * NUM_PER_TILE + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(A[(row + i) * K + a_col]);// load B_tile from global memory to shared memoryint b_row = k_base + threadIdx.y * NUM_PER_THREAD;FLOAT4(B_tile[(threadIdx.y * NUM_PER_THREAD + i) * NUM_PER_TILE + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(B[(b_row + i) * N + col]);}__syncthreads();// use register to compute the sum of A_tile * B_tilefor(int k = 0; k < NUM_PER_TILE; ++k){A_reg[0] = A_tile[(threadIdx.y * NUM_PER_THREAD + 0) * NUM_PER_TILE + k];A_reg[1] = A_tile[(threadIdx.y * NUM_PER_THREAD + 1) * NUM_PER_TILE + k];A_reg[2] = A_tile[(threadIdx.y * NUM_PER_THREAD + 2) * NUM_PER_TILE + k];A_reg[3] = A_tile[(threadIdx.y * NUM_PER_THREAD + 3) * NUM_PER_TILE + k];FLOAT4(B_reg[0]) = FLOAT4(B_tile[k * NUM_PER_TILE + threadIdx.x * NUM_PER_THREAD]);for(int i = 0; i < NUM_PER_THREAD; ++i){for(int j = 0; j < NUM_PER_THREAD; ++j){sum[i * NUM_PER_THREAD + j] += A_reg[i] * B_reg[j];}}}__syncthreads(); }// write the result to Cfloat* C_start = C + blockIdx.y * NUM_PER_TILE * N + blockIdx.x * NUM_PER_TILE;for(int i = 0; i < NUM_PER_THREAD; ++i){FLOAT4(C_start[(threadIdx.y * NUM_PER_THREAD + i) * N + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(sum[i * NUM_PER_THREAD]);}
}
下面是该代码的详细分析:(from ChatGPT)
1. 核函数签名与参数
template<unsigned int NUM_PER_TILE, unsigned int NUM_PER_THREAD>
__global__ void cuda_sgemm_v6_register_outer_product_float4(float* A, float* B, float* C, const int M, const int N, const int K)
- 模板参数:
NUM_PER_TILE
:每个线程块每次处理的分块矩阵的大小,也就是将NUM_PER_TILE * NUM_PER_TILE
大小的矩阵加载到共享内存中,相当于之前的BLOCK_SIZE
变量,在当前示例中其值为 64NUM_PER_THREAD
:每个线程处理的元素数量,在当前示例中为 4- 注意,这里可能会引发歧义,实际上在当前核函数中每个线程处理的元素数量是 4x4=16,
NUM_PER_THREAD
在这里表示的是行和列方向上各处理 4 个元素
- 注意,这里可能会引发歧义,实际上在当前核函数中每个线程处理的元素数量是 4x4=16,
2. 线程索引计算
int row = (blockIdx.y * blockDim.y + threadIdx.y) * NUM_PER_THREAD;
int col = (blockIdx.x * blockDim.x + threadIdx.x) * NUM_PER_THREAD;
row
:当前线程处理的 A 矩阵的行索引,乘以NUM_PER_THREAD
得到实际起始行col
:当前线程处理的 B 矩阵的列索引,乘以NUM_PER_THREAD
得到实际起始列
3. 共享内存分配
extern __shared__ float shared_mem[];
float* A_tile = shared_mem;
float* B_tile = shared_mem + NUM_PER_TILE * NUM_PER_TILE;
- 动态共享内存
shared_mem
被划分为两部分:A_tile
:前NUM_PER_TILE * NUM_PER_TILE = 4096
个 float,用于缓存 A 矩阵的分块B_tile
:后 4096 个 float,用于缓存 B 矩阵的分块
- 总共享内存需求:8192 个 float(32KB)
4. 寄存器变量声明
float A_reg[NUM_PER_THREAD] = {0.0f}; // 4个寄存器缓存A数据
float B_reg[NUM_PER_THREAD] = {0.0f}; // 4个寄存器缓存B数据
float sum[NUM_PER_THREAD * NUM_PER_THREAD] = {0.0f}; // 16个寄存器存储结果
- 每个线程:
- 缓存 4 个 A 元素和 4 个 B 元素
- 累加 4x4=16 个部分和
5. 主计算循环
for(int k_base = 0; k_base < K; k_base += NUM_PER_TILE)
- 循环结构:
- 在 K 维度上分块处理,步长为
NUM_PER_TILE = 64
- 对于 K = 512,共需要 512 / 64 = 8 次迭代
- 在 K 维度上分块处理,步长为
5.1 数据加载阶段
for(int i = 0; i < NUM_PER_THREAD; ++i){// 加载A_tileint a_col = k_base + threadIdx.x * NUM_PER_THREAD;FLOAT4(A_tile[(threadIdx.y * NUM_PER_THREAD + i) * NUM_PER_TILE + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(A[(row + i) * K + a_col]);// 加载B_tileint b_row = k_base + threadIdx.y * NUM_PER_THREAD;FLOAT4(B_tile[(threadIdx.y * NUM_PER_THREAD + i) * NUM_PER_TILE + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(B[(b_row + i) * N + col]);
}
A_tile
加载- 每个线程加载 4 个
FLOAT4
(16 个 float) - 访问模式:连续访问 A 的 4 行
- 使用
FLOAT4
实现向量化加载
- 每个线程加载 4 个
B_tile
加载- 同样加载 4 个
FLOAT4
- 访问模式:连续访问 B 的 4 列
- 同样加载 4 个
5.2 同步
__syncthreads();
- 确保所有线程完成共享内存的加载后才开始计算
5.3 外积计算
for(int k = 0; k < NUM_PER_TILE; ++k){// 加载A的4个元素(一列)A_reg[0] = A_tile[(threadIdx.y * NUM_PER_THREAD + 0) * NUM_PER_TILE + k];A_reg[1] = A_tile[(threadIdx.y * NUM_PER_THREAD + 1) * NUM_PER_TILE + k];A_reg[2] = A_tile[(threadIdx.y * NUM_PER_THREAD + 2) * NUM_PER_TILE + k];A_reg[3] = A_tile[(threadIdx.y * NUM_PER_THREAD + 3) * NUM_PER_TILE + k];// 用FLOAT4加载B的4个元素(一行)FLOAT4(B_reg[0]) = FLOAT4(B_tile[k * NUM_PER_TILE + threadIdx.x * NUM_PER_THREAD]);// 计算4x4外积for(int i = 0; i < NUM_PER_THREAD; ++i){for(int j = 0; j < NUM_PER_THREAD; ++j){sum[i * NUM_PER_THREAD + j] += A_reg[i] * B_reg[j];}}
}
- 关键优化点:
- 寄存器缓存:
A_reg
和B_reg
缓存共享内存数据 - 外积计算:4 元素列向量 x 4 元素行向量 ➡ 4x4 矩阵
- 向量化加载:
B_reg
使用FLOAT4
一次性加载 4 个元素
- 寄存器缓存:
- 计算过程:
- 对每个 k(0~63):
- 从
A_tile
加载一列 4 个元素到 A_reg - 从
B_tile
加载一行 4 个元素到 B_reg(用FLOAT4
) - 计算 4x4 外积并累加到 sum
- 从
- 对每个 k(0~63):
整个过程如下图所示(以 thread(0, 0)
线程为例):
每个线程处理 A_tile
四行数据与 B_tile
四列数据相乘,每次加载 A_tile
四个数据和 B_tile
四个数据到寄存器,其中 B_tile
的四个数据内存连续,因此可以使用 float4 向量化加载,循环 NUM_PER_TILE
次
5.4 第二次同步
__syncthreads();
- 确保所有线程完成当前块的计算后再加载下一块
6. 结果写回
float* C_start = C + blockIdx.y * NUM_PER_TILE * N + blockIdx.x * NUM_PER_TILE;
for(int i = 0; i < NUM_PER_THREAD; ++i){FLOAT4(C_start[(threadIdx.y * NUM_PER_THREAD + i) * N + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(sum[i * NUM_PER_THREAD]);
}
- 写回优化:
- 使用
FLOAT4
一次性写回 4 个结果 - 每个线程写回 4x4=16 个结果元素
- 写回位置计算考虑了线程块和线程的索引
- 使用
和 v5 版本相比的差异如下:
- 分块大小
- 之前:
BLOCK_SIZE = 16
- 现在:
NUM_PER_TILE = 64
- 之前:
- 线程组织
- 之前:4x46=64 线程/块
- 现在:16x16=256 线程/块
- 外积规模
- 之前:2x2 外积
- 现在:4x4 外积
v6 版本的实现通过更大的分块、更大的线程和更大的外积规模,进一步提高了计算效率和内存访问信息,此外通过 float4 向量化加载 B 矩阵和写回 C 矩阵进一步提高了吞吐量
性能和带宽测试结果如下:
优化手段 | 矩阵维度 | Grid | Block | 耗时(us) | Memory Throughout(%) | DRAM Throughout(%) |
---|---|---|---|---|---|---|
v0_global_memory | 512x512 | (32,32) | (16,16) | 471.78 | 96.94 | 1.56 |
v1_shared_memory | 256x256 | (16,16) | (16,16) | 82.11 | 78.92 | 1.84 |
v2_shared_memory_sliding_window | 512x512 | (32,32) | (16,16) | 362.50 | 94.45 | 7.05 |
v3_increase_work_of_per_thread | 512x512 | (16,16) | (16,16) | 204.26 | 84.01 | 3.64 |
v4_using_float4 | 512x512 | (32,32) | (4,16) | 209.60 | 91.99 | 3.44 |
v5_register_outer_product | 512x512 | (32,32) | (4,16) | 206.18 | 79.10 | 3.50 |
v6_register_outer_product_float4 | 512x512 | (8,8) | (16,16) | 84.99 | 60.38 | 7.28 |
3. 优化技巧7:global memory转置再存放shared memory
在 v6 版本中 A_tile
从共享内存到寄存器的加载我们并没有使用 float4,因为它们之间的内存并不是连续的。那这里我们可以考虑在将 A 矩阵存入 shared memory 之前做一次转置,这样就可以也使用 float4 来处理 A_tile
,如下图所示:
相比于之前的版本(v6)这里我们考虑在将矩阵 A 的元素从 global memory 加载到 shared memory 时做一个转置,这样我们在做外积计算时就可以直接取 A_tile
中的连续 4 个元素
加载流程如下:
这个实现方式就是借用 4 个临时的寄存器来完成转置操作,首先通过 float4 向量化读取 A 中的 4 个元素并存储在临时的 4 个寄存器中,接着将 4 个寄存器的值按照转置的方式填充到 A_tile
共享内存中,然后依次循环其他加载的元素,最终它相当于把之前 A_tile
整个给转置过来了
计算流程如下:
这个实现在外积计算时会相对简单些,因为 A_tile
是转置存储的,因此我们现在完全可以像加载 B_tile
一样通过 float4 来加载 A_tile
了,所以在上图中我们可以清晰的看到 A_tile
和 B_tile
的加载相同
代码如下:
template<unsigned int NUM_PER_TILE, unsigned int NUM_PER_THREAD>
__global__ void cuda_sgemm_v7_A_smen_transpose(float* A, float* B, float* C, const int M, const int N, const int K){int row = (blockIdx.y * blockDim.y + threadIdx.y) * NUM_PER_THREAD;int col = (blockIdx.x * blockDim.x + threadIdx.x) * NUM_PER_THREAD;extern __shared__ float shared_mem[];float* A_tile = shared_mem;float* B_tile = shared_mem + NUM_PER_TILE * NUM_PER_TILE;float A_reg[NUM_PER_THREAD] = {0.0f};float B_reg[NUM_PER_THREAD] = {0.0f};float A_load_reg[NUM_PER_THREAD] = {0.0f};float sum[NUM_PER_THREAD * NUM_PER_THREAD] = {0.0f};for(int k_base = 0; k_base < K; k_base += NUM_PER_TILE){for(int i = 0; i < NUM_PER_THREAD; ++i){// col-major load A_tile from global memory to shared memoryint a_col = k_base + threadIdx.x * NUM_PER_THREAD;FLOAT4(A_load_reg[0]) = FLOAT4(A[(row + i) * K + a_col]);A_tile[(threadIdx.x * NUM_PER_THREAD + 0) * NUM_PER_TILE + threadIdx.y * NUM_PER_THREAD + i] = A_load_reg[0];A_tile[(threadIdx.x * NUM_PER_THREAD + 1) * NUM_PER_TILE + threadIdx.y * NUM_PER_THREAD + i] = A_load_reg[1];A_tile[(threadIdx.x * NUM_PER_THREAD + 2) * NUM_PER_TILE + threadIdx.y * NUM_PER_THREAD + i] = A_load_reg[2];A_tile[(threadIdx.x * NUM_PER_THREAD + 3) * NUM_PER_TILE + threadIdx.y * NUM_PER_THREAD + i] = A_load_reg[3];// load B_tile from global memory to shared memoryint b_row = k_base + threadIdx.y * NUM_PER_THREAD;FLOAT4(B_tile[(threadIdx.y * NUM_PER_THREAD + i) * NUM_PER_TILE + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(B[(b_row + i) * N + col]);}__syncthreads();// use register to compute the sum of A_tile * B_tilefor(int k = 0; k < NUM_PER_TILE; ++k){FLOAT4(A_reg[0]) = FLOAT4(A_tile[k * NUM_PER_TILE + threadIdx.y * NUM_PER_THREAD]);FLOAT4(B_reg[0]) = FLOAT4(B_tile[k * NUM_PER_TILE + threadIdx.x * NUM_PER_THREAD]);for(int i = 0; i < NUM_PER_THREAD; ++i){for(int j = 0; j < NUM_PER_THREAD; ++j){sum[i * NUM_PER_THREAD + j] += A_reg[i] * B_reg[j];}}}__syncthreads(); }// write the result to Cfloat* C_start = C + blockIdx.y * NUM_PER_TILE * N + blockIdx.x * NUM_PER_TILE;for(int i = 0; i < NUM_PER_THREAD; ++i){FLOAT4(C_start[(threadIdx.y * NUM_PER_THREAD + i) * N + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(sum[i * NUM_PER_THREAD]);}
}
v7 版本最核心的改进是对 A_tile
进行转置存储,使得从共享内存加载到寄存器时也能使用 FLOAT4
向量化加载,这种优化带来了以下关键变化:
A_tile
内存布局重构:从行主序改为列主序- 加载模式改变:使用
FLOAT4
同时加载 A 和 B 的数据 - 新增中间寄存器:
A_load_reg
用于临时存储转置数据
关键代码分析如下:(from ChatGPT)
1. 新增寄存器变量
float A_load_reg[NUM_PER_THREAD] = {0.0f}; // 新增的临时寄存器
- 作用:临时存储从全局内存加载的 A 矩阵数据,用于转置写入共享内存
- 必要性:实现从行优先到列优先的布局转换
2. A_tile
加载逻辑重构(核心变化)
// v6版本(原始行主序加载):
FLOAT4(A_tile[(threadIdx.y * NUM_PER_THREAD + i) * NUM_PER_TILE + threadIdx.x * NUM_PER_THREAD]) =
FLOAT4(A[(row + i) * K + a_col]);// v7版本(转置列主序加载):
FLOAT4(A_load_reg[0]) = FLOAT4(A[(row + i) * K + a_col]);
A_tile[(threadIdx.x * NUM_PER_THREAD + 0) * NUM_PER_TILE + threadIdx.y * NUM_PER_THREAD + i] = A_load_reg[0];
A_tile[(threadIdx.x * NUM_PER_THREAD + 1) * NUM_PER_TILE + threadIdx.y * NUM_PER_THREAD + i] = A_load_reg[1];
A_tile[(threadIdx.x * NUM_PER_THREAD + 2) * NUM_PER_TILE + threadIdx.y * NUM_PER_THREAD + i] = A_load_reg[2];
A_tile[(threadIdx.x * NUM_PER_THREAD + 3) * NUM_PER_TILE + threadIdx.y * NUM_PER_THREAD + i] = A_load_reg[3];
- 转置操作解析:
- 1. 用
FLOAT4
从全局内存连续加载 4 个元素到A_load_reg
- 2. 将这些元素分散存储到共享内存的不同位置,实现转置
- 3. 存储模式:
(threadIdx.x * NUM_PER_THREAD + n)
决定列,(threadIdx.y * NUM_PER_THREAD + i)
决定行
- 1. 用
- 内存布局对比:
- v6:
A_tile
按行存储,同一行的元素在内存中连续 - v7:
A_tile
按列存储,同一列的元素在内存中连续
- v6:
整个加载过程如下图所示(以 thread(0, 0)
为例):
数据从 global memory 加载到寄存器中和之前保持一致,都是行主序加载。但是将数据从寄存器加载到 shared memory 中时是转置加载的,也就是列主序加载,这样我们在后续计算时也可以使用 float4 来加载 A_tile
3. 计算核心的优化(关键改进)
// v6版本(标量加载A):
A_reg[0] = A_tile[(threadIdx.y * NUM_PER_THREAD + 0) * NUM_PER_TILE + k];
// ...(加载4个标量)// v7版本(FLOAT4加载A):
FLOAT4(A_reg[0]) = FLOAT4(A_tile[k * NUM_PER_TILE + threadIdx.y * NUM_PER_THREAD]);
- 优化效果:原来需要 4 次单独加载,现在 1 次
FLOAT4
完成
我们举个简单的例子来说明下这一过程
假设矩阵 A 和矩阵 B 都是 4x4 大小,则二者相乘计算如下图所示:
其中 sum[0][0]
的计算结果如图中所示,A 的一行和 B 的一列相乘
假设矩阵 A 和矩阵 B 都加载到了 A_tile
和 B_tile
中,且 A_tile
是列主序加载,则此时 sum[0][0]
的计算如下图所示:
由于 A_tile
是列主序存储,因此可以和 B_tile
一样通过 float4 向量化加载,加载到寄存器后由于是外积相乘,因此每次循环恰好计算相应元素的乘积,最终的结果也和前面保持一致
性能和带宽测试结果如下:
优化手段 | 矩阵维度 | Grid | Block | 耗时(us) | Memory Throughout(%) | DRAM Throughout(%) |
---|---|---|---|---|---|---|
v0_global_memory | 512x512 | (32,32) | (16,16) | 471.78 | 96.94 | 1.56 |
v1_shared_memory | 256x256 | (16,16) | (16,16) | 82.11 | 78.92 | 1.84 |
v2_shared_memory_sliding_window | 512x512 | (32,32) | (16,16) | 362.50 | 94.45 | 7.05 |
v3_increase_work_of_per_thread | 512x512 | (16,16) | (16,16) | 204.26 | 84.01 | 3.64 |
v4_using_float4 | 512x512 | (32,32) | (4,16) | 209.60 | 91.99 | 3.44 |
v5_register_outer_product | 512x512 | (32,32) | (4,16) | 206.18 | 79.10 | 3.50 |
v6_register_outer_product_float4 | 512x512 | (8,8) | (16,16) | 84.99 | 60.38 | 7.28 |
v7_A_smem_transpose | 512x512 | (8,8) | (16,16) | 118.21 | 65.85 | 5.39 |
4. 优化技巧8:使用double buffer加速矩阵乘法
之前的版本中我们的 shared memory 只有一组,如下图所示:
这里考虑使用两组,也就是 double buffer 的优化策略,其中一组先预填充,另一组异步加载,然后对预填充的缓冲区计算,这样可以确保加载和计算重叠,有助于延迟隐藏,具体的实现流程我们还是来看代码慢慢讲解吧
代码如下:
template<unsigned int BLOCK_SIZE_M,unsigned int BLOCK_SIZE_N,unsigned int BLOCK_SIZE_K,unsigned int NUM_PER_THREAD>
__global__ void cuda_sgemm_v8_double_buffer(float* A, float* B, float* C, const int M, const int N, const int K){float* A_start = A + blockIdx.y * BLOCK_SIZE_M * K;float* B_start = B + blockIdx.x * BLOCK_SIZE_N;// double bufferextern __shared__ float shared_mem[];int A_tile_per_buffer_size = BLOCK_SIZE_K * BLOCK_SIZE_M;int B_tile_per_buffer_size = BLOCK_SIZE_K * BLOCK_SIZE_N;float* A_tile = shared_mem;float* B_tile = shared_mem + 2 * A_tile_per_buffer_size;float A_reg[NUM_PER_THREAD] = {0.0f};float B_reg[NUM_PER_THREAD] = {0.0f};float A_load_reg[4] = {0.0f};float sum[NUM_PER_THREAD * NUM_PER_THREAD] = {0.0f};// re-arrange the layoutint tid = threadIdx.y * blockDim.x + threadIdx.x; // 0~256int A_tile_tx = tid % (BLOCK_SIZE_K / 4); // 0~1int A_tile_ty = tid / (BLOCK_SIZE_K / 4); // 0~127int B_tile_tx = tid % (BLOCK_SIZE_N / 4); // 0~31int B_tile_ty = tid / (BLOCK_SIZE_N / 4); // 0~7// prefetch first tileFLOAT4(A_load_reg[0]) = FLOAT4(A_start[A_tile_ty * K + A_tile_tx * 4]);A_tile[(A_tile_tx * 4 + 0) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[0]; A_tile[(A_tile_tx * 4 + 1) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[1]; A_tile[(A_tile_tx * 4 + 2) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[2]; A_tile[(A_tile_tx * 4 + 3) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[3]; FLOAT4(B_tile[B_tile_ty * BLOCK_SIZE_N + B_tile_tx * 4]) = FLOAT4(B_start[B_tile_ty * N + B_tile_tx * 4]);__syncthreads();int buffer_idx = 1;for(int k_base = BLOCK_SIZE_K; k_base < K; k_base += BLOCK_SIZE_K){// prefetch next tileFLOAT4(A_load_reg[0]) = FLOAT4(A_start[A_tile_ty * K + A_tile_tx * 4 + k_base]);A_tile[buffer_idx * A_tile_per_buffer_size + (A_tile_tx * 4 + 0) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[0]; A_tile[buffer_idx * A_tile_per_buffer_size + (A_tile_tx * 4 + 1) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[1]; A_tile[buffer_idx * A_tile_per_buffer_size + (A_tile_tx * 4 + 2) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[2]; A_tile[buffer_idx * A_tile_per_buffer_size + (A_tile_tx * 4 + 3) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[3]; FLOAT4(B_tile[buffer_idx * B_tile_per_buffer_size + B_tile_ty * BLOCK_SIZE_N + B_tile_tx * 4]) = FLOAT4(B_start[(B_tile_ty + k_base) * N + B_tile_tx * 4]);// toggle buffer indexbuffer_idx = buffer_idx ^ 1;// compute current tilefor(int k = 0; k < BLOCK_SIZE_K; ++k){// load A_tile and B_tile from shared memory to registerFLOAT4(A_reg[0]) = FLOAT4(A_tile[buffer_idx * A_tile_per_buffer_size + k * BLOCK_SIZE_M + threadIdx.y * NUM_PER_THREAD]);FLOAT4(A_reg[4]) = FLOAT4(A_tile[buffer_idx * A_tile_per_buffer_size + k * BLOCK_SIZE_M + threadIdx.y * NUM_PER_THREAD + 4]);FLOAT4(B_reg[0]) = FLOAT4(B_tile[buffer_idx * B_tile_per_buffer_size + k * BLOCK_SIZE_N + threadIdx.x * NUM_PER_THREAD]);FLOAT4(B_reg[4]) = FLOAT4(B_tile[buffer_idx * B_tile_per_buffer_size + k * BLOCK_SIZE_N + threadIdx.x * NUM_PER_THREAD + 4]);// use register to compute the sum of A_tile * B_tilefor(int i = 0; i < NUM_PER_THREAD; ++i){for(int j = 0; j < NUM_PER_THREAD; ++j){sum[i * NUM_PER_THREAD + j] += A_reg[i] * B_reg[j];}}}__syncthreads();}buffer_idx = buffer_idx ^ 1;for(int k = 0; k < BLOCK_SIZE_K; ++k){// compute the last tileFLOAT4(A_reg[0]) = FLOAT4(A_tile[buffer_idx * A_tile_per_buffer_size + k * BLOCK_SIZE_M + threadIdx.y * NUM_PER_THREAD]);FLOAT4(A_reg[4]) = FLOAT4(A_tile[buffer_idx * A_tile_per_buffer_size + k * BLOCK_SIZE_M + threadIdx.y * NUM_PER_THREAD + 4]);FLOAT4(B_reg[0]) = FLOAT4(B_tile[buffer_idx * B_tile_per_buffer_size + k * BLOCK_SIZE_N + threadIdx.x * NUM_PER_THREAD]);FLOAT4(B_reg[4]) = FLOAT4(B_tile[buffer_idx * B_tile_per_buffer_size + k * BLOCK_SIZE_N + threadIdx.x * NUM_PER_THREAD + 4]);for(int i = 0; i < NUM_PER_THREAD; ++i){for(int j = 0; j < NUM_PER_THREAD; ++j){sum[i * NUM_PER_THREAD + j] += A_reg[i] * B_reg[j];}}} // write the result to Cfloat* C_start = C + blockIdx.y * BLOCK_SIZE_M * N + blockIdx.x * BLOCK_SIZE_N;for(int i = 0; i < NUM_PER_THREAD; ++i){FLOAT4(C_start[(threadIdx.y * NUM_PER_THREAD + i) * N + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(sum[i * NUM_PER_THREAD]);FLOAT4(C_start[(threadIdx.y * NUM_PER_THREAD + i) * N + threadIdx.x * NUM_PER_THREAD + 4]) = FLOAT4(sum[i * NUM_PER_THREAD + 4]);}
}
下面是该代码的详细分析:(from ChatGPT)
1. 核函数签名与参数
template<unsigned int BLOCK_SIZE_M,unsigned int BLOCK_SIZE_N,unsigned int BLOCK_SIZE_K,unsigned int NUM_PER_THREAD>
__global__ void cuda_sgemm_v8_double_buffer(float* A, float* B, float* C, const int M, const int N, const int K)
- 模板参数
BLOCK_SIZE_M = 128
:每个线程块处理的 M 维度大小BLOCK_SIZE_N = 128
:每个线程块处理的 N 维度大小BLOCK_SIZE_K = 8
:每个线程块处理的 K 维度大小NUM_PER_THREAD
:每个线程 x 方向和 y 方向分别处理的元素数量
- 启动参数:
block(16, 16)
:每个线程块包含 256 个线程grid(4, 4)
:整个网格包含 16 个线程块shared_mem_size
:双缓冲区所需共享内存
2. 输入矩阵起始指针偏移
float* A_start = A + blockIdx.y * BLOCK_SIZE_M * K;
float* B_start = B + blockIdx.x * BLOCK_SIZE_N;
A_start
:当前 block 块处理的 A 矩阵的起始指针B_start
:当前 block 块处理的 B 矩阵的起始指针
关于索引的计算,每个 block 要处理的是 A 中子矩阵 BLOCK_SIZE_M * K 与 B 中子矩阵 K * BLOCK_SIZE_N 的乘积。因此 A 矩阵的起始指针是 blockIdx.y * (BLOCK_SIZE_M * K)
,其中 blockIdx.y
表示 y 方向上当前 block 的索引;B 矩阵的起始指针是 blockIdx.x * BLOCK_SIZE_N
,其中 blockIdx.x
表示 x 方向上当前 block 的索引
如上图所示,block(2, 1) 处理的 A 矩阵和 B 矩阵的起始指针是 ❌ 的位置
3. 共享内存分配(双缓冲)
extern __shared__ float shared_mem[];
int A_tile_per_buffer_size = BLOCK_SIZE_K * BLOCK_SIZE_M; // 8*128=1024
int B_tile_per_buffer_size = BLOCK_SIZE_K * BLOCK_SIZE_N; // 8*128=1024
float* A_tile = shared_mem; // 缓冲区0的A_tile
float* B_tile = shared_mem + 2 * A_tile_per_buffer_size; // 缓冲区0的B_tile
- 分配共享内存用于矩阵 A 的两个分块缓冲区(
A_tile
)和矩阵 B 的两个分块缓冲区(B_tile
) - 总共四个 tile 分块,两个用于预加载下一个块(prefetch),两个用于计算当前块
共享内存布局为:
[A_tile_buffer_0, A_tile_buffer_1, B_tile_buffer_0, B_tile_buffer_1]
4. 寄存器初始化
float A_reg[NUM_PER_THREAD] = {0.0f};
float B_reg[NUM_PER_THREAD] = {0.0f};
float A_load_reg[4] = {0.0f};
float sum[NUM_PER_THREAD * NUM_PER_THREAD] = {0.0f};
- 寄存器用于局部存储:加载小块、计算用值、累积结果
5. 线程布局重排(线程索引与 tile 索引转换)
int tid = threadIdx.y * blockDim.x + threadIdx.x; // 0~255
int A_tile_tx = tid % (BLOCK_SIZE_K / 4); // BLOCK_SIZE_K=8 → 0~1
int A_tile_ty = tid / (BLOCK_SIZE_K / 4); // 0~127
int B_tile_tx = tid % (BLOCK_SIZE_N / 4); // BLOCK_SIZE_N=128 → 0~31
int B_tile_ty = tid / (BLOCK_SIZE_N / 4); // 0~7
- 将 256 个线程重新划分为不同的加载工作组(如下图所示)
A_tile
加载:128 个线程组(每个组 2 个线程)B_tile
加载:8 个线程组(每个组 32 个线程)
6. 预取一个数据块
// prefetch first tile
FLOAT4(A_load_reg[0]) = FLOAT4(A_start[A_tile_ty * K + A_tile_tx * 4]);
A_tile[(A_tile_tx * 4 + 0) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[0];
A_tile[(A_tile_tx * 4 + 1) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[1];
A_tile[(A_tile_tx * 4 + 2) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[2];
A_tile[(A_tile_tx * 4 + 3) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[3]; FLOAT4(B_tile[B_tile_ty * BLOCK_SIZE_N + B_tile_tx * 4]) = FLOAT4(B_start[B_tile_ty * N + B_tile_tx * 4]);__syncthreads();
- 使用
FLOAT4
宏从 global memory 读取 A A_tile
采用转置存储(列主序),同理B_tile
也进行加载,实现计算阶段的内存访问连续性- 等待所有线程完成预加载
这里的重点是各个索引的计算,下面我们简要分析下:(from ChatGPT)
6.1 A_tile
预取的索引计算
全局内存加载索引
FLOAT4(A_load_reg[0]) = FLOAT4(A_start[A_tile_ty * K + A_tile_tx * 4]);
A_start
:当前线程块对应的 A 矩阵起始指针- 索引分解:
A_tile_ty
:范围 0~127A_tile_tx
:范围 0~1
- 实际访问:
- 每个线程处理 4 个连续元素(
FLOAT4
) - 访问位置:
A_start + (A_tile_ty * K) + (A_tile_tx * 4)
- 相当于:
- 在行方向:
A_tile_ty
(0~127) - 在列方向:
A_tile_tx * 4
(0 或 4)
- 在行方向:
- 每个线程处理 4 个连续元素(
- 线程分工:
- 256 个线程分成 128 组(
A_tile_ty
= 0~127),每组 2 个线程(A_tile_tx
= 0~1) - 每组线程负责加载 8 个连续元素(2 个 FLOAT4)
- 256 个线程分成 128 组(
我们以第一个线程块 block(0, 0)
为例来讲解 A_tile
预取第一个数据块部分的索引计算,如下图所示:
图中灰色区域就是 block(0, 0)
线程块需要加载的数据,总共 128x8 大小的元素数量,一个 block 包含 256 个线程,每个线程负责加载 4 个元素。线程索引的转换在步骤 5 中来完成的,(16, 16)➡(128, 2)
共享内存存储索引(转置关键)
A_tile[(A_tile_tx * 4 + n) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[n];
- 存储布局:
- 将全局内存的行主序转为共享内存的列主序
- 公式:
(col * BLOCK_SIZE_M) + row
- 索引分解:
A_tile_tx * 4 + n
:列索引(0~7)A_tile_ty
:行索引(0~127)
- 转置效果:
- 全局内存中的行
A_tile_ty
变为共享内存中的列A_tile_ty
- 全局内存中的列
A_tile_tx * 4 + n
变为共享内存中的行A_tile_tx * 4 + n
- 全局内存中的行
整个加载过程如下图所示:
6.2 B_tile
预取的索引计算
全局内存加载索引
... = FLOAT4(B_start[B_tile_ty * N + B_tile_tx * 4]);
B_start
:当前线程块对应的 B 矩阵的起始指针- 索引分解:
B_tile_ty
:范围 0~7B_tile_tx
:范围 0~31
- 实际访问:
- 每个线程处理 4 个连续元素(
FLOAT4
) - 访问位置:
B_start + (B_tile_ty * N) + (B_tile_tx * 4)
- 相当于:
- 在行方向:
B_tile_ty
(0~7) - 在列方向:
B_tile_tx * 4
(0~124)
- 在行方向:
- 每个线程处理 4 个连续元素(
- 线程分工:
- 256 个线程分成 8 组(
B_tile_ty
= 0~7),每组 32 个线程(B_tile_tx
= 0~31) - 每组线程负责加载 128 个连续元素(32 个
FLOAT4
)
- 256 个线程分成 8 组(
我们以第一个线程块 block(0, 0)
为例来讲解 B_tile
预取第一个数据块部分的索引计算,如下图所示:
图中灰色区域就是 block(0, 0)
线程块需要加载的数据,总共 8x128 大小的元素数量,一个 block 包含 256 个线程,每个线程负责加载 4 个元素。线程索引的转换在步骤 5 中来完成的,(16, 16)➡(32, 8)
共享内存存储索引
B_tile[B_tile_ty * BLOCK_SIZE_N + B_tile_tx * 4] = ...
- 存储布局:
- 保持行主序(不转置)
- 公式:
(row * BLOCK_SIZE_N + col)
- 索引分解:
B_tile_ty
:行索引(0~7)B_tile_tx
:列索引(0~124)
整个加载过程如下图所示:
由于我们是行主序加载存储,因此索引计算方式相比 A_tile
来说更加简单
7. 主计算循环(双缓冲核心)
int buffer_idx = 1; // 初始缓冲区索引
for(int k_base = BLOCK_SIZE_K; k_base < K; k_base += BLOCK_SIZE_K){// 1. 预取下一块到非活动缓冲区FLOAT4(A_load_reg[0]) = FLOAT4(A_start[A_tile_ty * K + A_tile_tx * 4 + k_base]);A_tile[buffer_idx * A_tile_per_buffer_size + (A_tile_tx * 4 + 0) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[0];// ...(写入4个元素)FLOAT4(B_tile[buffer_idx * B_tile_per_buffer_size + B_tile_ty * BLOCK_SIZE_N + B_tile_tx * 4]) = FLOAT4(B_start[(B_tile_ty + k_base) * N + B_tile_tx * 4]);// 2. 切换缓冲区索引buffer_idx = buffer_idx ^ 1; // 0↔1切换// 3. 计算当前块(使用另一个缓冲区)for(int k = 0; k < BLOCK_SIZE_K; ++k){FLOAT4(A_reg[0]) = FLOAT4(A_tile[buffer_idx * A_tile_per_buffer_size + k * BLOCK_SIZE_M + threadIdx.y * NUM_PER_THREAD]);FLOAT4(A_reg[4]) = FLOAT4(/*... +4*/); // 加载8个元素FLOAT4(B_reg[0]) = FLOAT4(B_tile[buffer_idx * B_tile_per_buffer_size + k * BLOCK_SIZE_N + threadIdx.x * NUM_PER_THREAD]);FLOAT4(B_reg[4]) = FLOAT4(/*... +4*/);// 4. 8x8外积计算for(int i = 0; i < NUM_PER_THREAD; ++i){for(int j = 0; j < NUM_PER_THREAD; ++j){sum[i * NUM_PER_THREAD + j] += A_reg[i] * B_reg[j];}}}__syncthreads();
}
- 流水线设计:
- 当计算在使用 buffer0 时,异步加载下一块到 buffer1
- 下次迭代切换缓冲区,计算 buffer1 同时加载到 buffer0
- 优势:
- 计算与内存传输重叠
- 隐藏内存访问延迟
- 提高计算单元利用率
7.1 双缓冲索引管理
int buffer_idx = 1; // 初始缓冲区索引
for(int k_base = BLOCK_SIZE_K; k_base < K; k_base += BLOCK_SIZE_K){// ...预取和计算代码...buffer_idx = buffer_idx ^ 1; // 缓冲区切换
}
- 初始值:
buffer_idx = 1
(因为第 0 个缓冲区已在预取阶段填充) - 切换逻辑:
buffer_idx ^ 1
在 0 和 1 之间切换 - 双缓冲工作流程:
- 计算使用
buffer_idx
指向的缓冲区 - 同时预取数据到
buffer_idx ^ 1
指向的缓冲区 - 每次迭代后切换缓冲区
- 计算使用
7.2 A_tile
预取索引(下一块)
FLOAT4(A_load_reg[0]) = FLOAT4(A_start[A_tile_ty * K + A_tile_tx * 4 + k_base]);
A_tile[buffer_idx * A_tile_per_buffer_size + (A_tile_tx * 4 + n) * BLOCK_SIZE_M + A_tile_ty] = A_load_reg[n];
全局内存加载索引:
A_tile_ty * K
:行偏移(0~127 行,每行跳 K 元素)A_tile_tx * 4
:列偏移(0 或 4)k_base
:当前 K 维度的基偏移(8,16,…,504)
实际访问模式:
- 每个线程加载全局内存中相隔 K 元素的 4 个连续 float
- 整体访问模式是跨步的但合并的(coalesced)
整个加载过程如下图所示:
那其实加载与缓冲区 0 的索引计算一样,只是有一个 k_base
的偏移量,代表着处理下一个缓冲区
共享内存存储索引:
buffer_idx * A_tile_per_buffer_size
:选择缓冲区(0 或 1024)(A_tile_tx * 4 + n) * BLOCK_SIZE_M
:列主序的列计算(0~7 * 128)A_tile_ty
:行索引(0~127)
从寄存器到共享内存的加载如上图所示,值得注意的是这里我们加载的是 buffer1 缓冲区,因此有一个 buffer_idx * A_tile_per_buffer_size
的偏移量存在
7.3 B_tile
预取索引(下一块)
FLOAT4(B_tile[buffer_idx * B_tile_per_buffer_size + B_tile_ty * BLOCK_SIZE_N + B_tile_tx * 4]) =
FLOAT4(B_start[N * (B_tile_ty + k_base) + B_tile_tx * 4]);
全局内存加载索引:
B_tile_ty + k_base
:行索引(0~7 + 8,16,…,504)B_tile_tx * 4
:列索引(0~124)N * row + col
:行主序访问
访问特点:
- 每个线程加载 B 矩阵中连续的 4 个元素
- 访问模式是完全连续的
整个加载过程如下图所示:
和 A_tile
一样,这里也有一个 k_base
的偏移量
共享内存存储索引
buffer_idx * B_tile_per_buffer_size
:选择缓冲区(0 或 1024)B_tile_ty * BLOCK_SIZE_N
:行偏移(0~7 * 128)B_tile_tx * 4
:列偏移(0~124)
布局特点:
- 保持行主序存储
- 与全局内存布局一致
B_tile
缓冲区 1 从寄存器到共享内存的加载过程如上图所示,同样有一个 buffer_idx * B_tile_per_buffer_size
的偏移量存在
7.4 计算阶段索引(当前块)
buffer_idx = buffer_idx ^ 1;
for(int k = 0; k < BLOCK_SIZE_K; ++k){FLOAT4(A_reg[0]) = FLOAT4(A_tile[buffer_idx * A_tile_per_buffer_size + k * BLOCK_SIZE_M + threadIdx.y * NUM_PER_THREAD]);FLOAT4(A_reg[4]) = FLOAT4(/*...+4*/);FLOAT4(B_reg[0]) = FLOAT4(B_tile[buffer_idx * B_tile_per_buffer_size + k * BLOCK_SIZE_N + threadIdx.x * NUM_PER_THREAD]);FLOAT4(B_reg[4]) = FLOAT4(/*...+4*/);// ...外积计算...
}
A_tile
加载索引
buffer_idx * A_tile_per_buffer_size
:选择缓冲区k * BLOCK_SIZE_M
:K 维度偏移(0~7 * 128)threadIdx.y * NUM_PER_THREAD
:线程在 M 维度的偏移(0~15 * 8)
关键点:
- 由于
A_tile
是转置存储的,这里实际上是按列连续访问 - 每次加载 8 个连续元素(2 个
FLOAT4
)
B_tile
加载索引
buffer_idx * B_tile_per_buffer_size
:选择缓冲区k * BLOCK_SIZE_N
:K 维度偏移(0~7 * 128)threadIdx.x * NUM_PER_THREAD
:线程在 N 维度的偏移(0~15 * 8)
关键点:
- 保持行主序访问
- 每次加载 8 个连续元素(2 个
FLOAT4
)
从共享内存到寄存器的加载过程如下图所示(以 thread(0,0)
为例):
7.5 外积计算索引
for(int i = 0; i < NUM_PER_THREAD; ++i){for(int j = 0; j < NUM_PER_THREAD; ++j){sum[i * NUM_PER_THREAD + j] += A_reg[i] * B_reg[j];}
}
A_reg
索引:i
(0~7)B_reg
索引:j
(0~7)sum
索引:i * 8 + j
(0~63)
计算模式:
- 8 元素 A 列向量 x 8 元素 B 行向量 ➡ 8x8 外积
- 结果累加到 64 个局部和寄存器中
8. 处理最后一个数据块
buffer_idx = buffer_idx ^ 1; // 切换回最后一个缓冲区
for(int k = 0; k < BLOCK_SIZE_K; ++k){// 加载并计算最后一个块FLOAT4(A_reg[0]) = FLOAT4(/*...*/);// ...(完整8x8外积计算)
}
- 与前面一样,只是不再 prefetch
9. 结果写回
float* C_start = C + blockIdx.y * BLOCK_SIZE_M * N + blockIdx.x * BLOCK_SIZE_N;
for(int i = 0; i < NUM_PER_THREAD; ++i){FLOAT4(C_start[(threadIdx.y * NUM_PER_THREAD + i) * N + threadIdx.x * NUM_PER_THREAD]) = FLOAT4(sum[i * NUM_PER_THREAD]);FLOAT4(/*...+4*/) = FLOAT4(/*...+4*/); // 写回8个元素
}
9.1 输出矩阵 C 的起始位置计算
float* C_start = C + blockIdx.y * BLOCK_SIZE_M * N + blockIdx.x * BLOCK_SIZE_N;
blockIdx.y
维度:blockIdx.y * BLOCK_SIZE_M * N
:计算当前线程块在 M 维度的偏移- 每个线程块处理
BLOCK_SIZE_M = 128
行 - 乘以
N
得到正确的行偏移量(因为 C 是行主序)
blockIdx.x
维度:blockIdx.x * BLOCK_SIZE_N
:计算当前线程块在 N 维度的偏移- 每个线程块处理
BLOCK_SIZE_N = 128
行
- 组合效果:
- 定位到当前线程块负责计算的 C 矩阵子块的起始位置,和
A_start
、B_start
类似
- 定位到当前线程块负责计算的 C 矩阵子块的起始位置,和
9.2 线程到输出位置的映射
行索引计算
threadIdx.y * NUM_PER_THREAD + i
:threadIdx.y
:线程在块内的 y 坐标(0~15)NUM_PER_THREAD = 8
:每个线程负责 8 行i
:当前迭代(0~7)- 组合效果:0~127(覆盖
BLOCK_SIZE_M = 128
)
列索引计算
threadIdx.x * NUM_PER_THREAD
:threadIdx.x
:线程在块内的 y 坐标(0~15)NUM_PER_THREAD = 8
:每个线程负责 8 列
threadIdx.x * NUM_PER_THREAD + 4
:- 额外的列偏移,用于处理每个线程的 8 列中的后 4 列
v8 版本通过双缓冲共享内存、寄存器 blocking、数据预取(prefetching)与流水线方式来提高计算效率,其中:
- 1. 双缓冲共享内存:使用两个缓冲区来重叠数据传输和计算
- 一个缓冲区用于当前计算
- 另一个缓冲区用于异步加载下一批数据
- 2. 寄存器 blocking:每个线程使用寄存器缓存多个元素
- 3. 数据预取:提前从全局内存加载下一 tile 的数据
- 4. 流水线执行
- 在计算当前分块时, 异步加载下一个分块
- 通过
buffer_idx
在 0 和 1 之间切换,实现缓冲区轮换
- 5. 性能优势
- 隐藏了全局内存访问延迟
- 计算和内存传输可以并行进行
- 减少了线程等待时间
性能和带宽测试结果如下:
优化手段 | 矩阵维度 | Grid | Block | 耗时(us) | Memory Throughout(%) | DRAM Throughout(%) |
---|---|---|---|---|---|---|
v0_global_memory | 512x512 | (32,32) | (16,16) | 471.78 | 96.94 | 1.56 |
v1_shared_memory | 256x256 | (16,16) | (16,16) | 82.11 | 78.92 | 1.84 |
v2_shared_memory_sliding_window | 512x512 | (32,32) | (16,16) | 362.50 | 94.45 | 7.05 |
v3_increase_work_of_per_thread | 512x512 | (16,16) | (16,16) | 204.26 | 84.01 | 3.64 |
v4_using_float4 | 512x512 | (32,32) | (4,16) | 209.60 | 91.99 | 3.44 |
v5_register_outer_product | 512x512 | (32,32) | (4,16) | 206.18 | 79.10 | 3.50 |
v6_register_outer_product_float4 | 512x512 | (8,8) | (16,16) | 84.99 | 60.38 | 7.28 |
v7_A_smem_transpose | 512x512 | (8,8) | (16,16) | 118.21 | 65.85 | 5.39 |
v8_double_buffer | 512x512 | (4,4) | (16,16) | 135.71 | 31.56 | 4.44 |
OK,以上就是 sgemm 各种优化的代码实现了
结语
这篇文章中 sgemm 的一些优化技巧相比上篇文章来说复杂一些,博主经常被其中的索引计算搞破防,曾一度想放弃,不过静下心来画画图慢慢思考总是能理解的。在计算时可以先定位到当前 block 要处理的起始元素位置,然后思考 block 中每个 thread 负责处理几个元素,都是怎么处理的,行和列索引分别是多少,这样会相对简单一些
OK,以上就是整篇文章的全部内容了
总的来说,跟随 up 主一步步来实现还是能理解的,大家感兴趣的可以多看看 up 主的视频,还是非常不错的🤗
下载链接
- Sgemm 矩阵乘法代码下载链接【提取码:1234】
参考
- 【CUDA】Sgemm单精度矩阵乘法(已完结~)
- https://github.com/tpoisonooo/how-to-optimize-gemm/cuda
- 深入浅出GPU优化系列:GEMM优化(一)
- [施工中] CUDA GEMM 理论性能分析与 kernel 优化
- cuda 入门的正确姿势:how-to-optimize-gemm
- CUDA 矩阵乘法终极优化指南
- CUDA实现矩阵乘法的8种优化策略编程介绍
- https://chatgpt.com/