CUDA 矩阵分块乘法
一图胜千言,如下图所示,现在要通过矩阵分块的方式计算矩阵 A
乘以矩阵 B
的结果(记为矩阵 C
),假设:
A
矩阵的分块是:A11, A12, A21, A22, A31, A32
;B
矩阵的分块是:B11, B12, B13, B21, B22, B23
;C
矩阵的分块是:C11, C12, C13, C21, C22, C23, C31, C32, C33
;- 每个分块的大小是:
BLOCK_SIZE * BLOCK_SIZE
。
与之配套的核函数线程布局是:
9
个block
线程块,即Bk11, Bk12, Bk13, Bk21, Bk22, Bk23, Bk31, Bk32, Bk33
,与C
矩阵的9
个分块一一对应;- 每个
block
线程块的大小是:BLOCK_SIZE * BLOCK_SIZE
。
从上面的 结果矩阵分块 和 线程分块 的对应关系可以看到,每个 block
线程块负责产生 C
矩阵中的一个分块的计算结果,例如:Bk33
线程块负责计算 C33
分块的结果,由矩阵分块乘法的原理可知,C33 = A31 x B13 + A32 x B23
,也就是说:Bk33
线程块需要做的事情是:
- 读入
A31、B13
,进行矩阵乘法; - 读入
A32、B23
,进行矩阵乘法; - … //如果涉及更多的分块,重复上面的操作;
- 将上述矩阵乘法得到的所有矩阵按元素对应相加 //上述步骤中在每一步做完矩阵乘法之后,可以直接在
C33
分块上进行累加。
下面的代码来自 cuda-samples:
/*** 通过矩阵分块的方法计算矩阵乘法: C = A * B* wA 是 A 矩阵的宽度,wB 是 B 矩阵的宽度*/
template <int BLOCK_SIZE> __global__ void MatrixMulCUDA(float *C, float *A, float *B, int wA, int wB)
{// block 索引,假设当前是 Bk33 线程块,则 bx = 2,by = 2int bx = blockIdx.x;int by = blockIdx.y;// thread 索引int tx = threadIdx.x;int ty = threadIdx.y;// 当前 block 为了生成 C33 分块的计算结果,需要处理的 A 矩阵的第一个分块,即 A31 分块的第一个元素的地址,这里的 by 是 2int aBegin = wA * BLOCK_SIZE * by;// 当前 block 需要处理的 A 矩阵的最后一个分块第一行的结束位置,即 A32 分块第一行最后一个元素的地址int aEnd = aBegin + wA - 1;// A 矩阵在 x 方向上,连续两个分块的第一个元素的间隔,横着是 x 方向,竖着是 y 方向int aStep = BLOCK_SIZE;// 当前 block 为了生成 C33 分块的计算结果,需要处理的 B 矩阵的第一个分块,即 B13 分块的第一个元素的地址,这里的 bx 是 2int bBegin = BLOCK_SIZE * bx;// B 矩阵在 y 方向上,连续两个分块的第一个元素的间隔int bStep = BLOCK_SIZE * wB;// C33 分块每个元素的初始值,初始为 0,用于后续累加float Csub = 0;// 遍历当前 block(Bk33)需要处理的所有 A 矩阵分块(A31、A32)和 B 矩阵分块(B13、B23)// 上面计算出了 A 矩阵和 B 矩阵的分块步进幅度,因此很容易在遍历过程中拿到对应分块的起始地址for (int a = aBegin, b = bBegin; a <= aEnd; a += aStep, b += bStep) {// 声明用于存储 A 矩阵分块的共享内存,用于一个 block 内所有线程共享数据__shared__ float As[BLOCK_SIZE][BLOCK_SIZE];// 声明用于存储 B 矩阵分块的共享内存,用于一个 block 内所有线程共享数据__shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];// 发动一个 block 内的所有线程将 global memory 里的数据加载到 shared memory// 即加载 A 矩阵和 B 矩阵的对应分块As[ty][tx] = A[a + wA * ty + tx];Bs[ty][tx] = B[b + wB * ty + tx];// 等待一个 block 内的所有线程都完成数据加载操作__syncthreads();// 发动一个 block 内的所有线程对 A 矩阵和 B 矩阵的对应分块进行矩阵乘法操作// 对于某一个线程来说,需要完成 A 矩阵分块的某一行乘以 B 矩阵分块的某一列的操作,并求和// 针对多个分块矩阵,对求和结果进行累加
#pragma unrollfor (int k = 0; k < BLOCK_SIZE; ++k) {Csub += As[ty][k] * Bs[k][tx];}// 等待一个 block 内的所有线程都完成行乘列的操作__syncthreads();}// 将 A31 x B13 + A32 x B23 的结果写入 C33 分块的对应位置,// 这里的索引计算是难点,// C 矩阵的宽度等于 B 矩阵的宽度,所以先计算 y 方向跳过的元素个数:wB * BLOCK_SIZE * by,// 再计算 x 方向跳过的元素个数:BLOCK_SIZE * bx,// 再计算 C33 分块内跳过的元素个数:wB * ty + tx,// 再将上述 3 个式子的结果相加,得到 C33 分块中某一个元素相对于 C 矩阵的偏移int c = wB * BLOCK_SIZE * by + BLOCK_SIZE * bx;C[c + wB * ty + tx] = Csub;
}