triton学习笔记7: GEMM相关
这是之前的学习笔记
- triton puzzles part1
- triton puzzles part2
- triton puzzles part3
- triton tutorials part1
- triton tutorials: part2
- triton tutorials: part3
这是triton tutorials里最后一篇关于GEMM的系列了
GEMM的知识可以参考这篇,写的非常详细具体https://zhuanlan.zhihu.com/p/703256080
Group GEMM
from typing import Optional

import torch
import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_cuda():
    """Return True when Triton's active driver targets the CUDA backend."""
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_tma():
    """Return True on a CUDA device whose compute-capability major version is >= 9."""
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


def num_sms():
    """Return the CUDA device's multiprocessor count, or 148 when not running on CUDA."""
    if is_cuda():
        return torch.cuda.get_device_properties("cuda").multi_processor_count
    return 148


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'NUM_SM': 84}),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'NUM_SM': 128}),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'NUM_SM': 84}),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'NUM_SM': 128}),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'NUM_SM': num_sms()}),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'NUM_SM': num_sms()}),
    ],
    key=['group_size'],
)
@triton.jit
def grouped_matmul_kernel(
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute every GEMM of a group in a single launch.

    Each program starts at tile index ``program_id(0)`` and advances by
    ``NUM_SM`` through the tiles of all problems, taken in group order.
    Loads and stores are unmasked, so every M/N/K is assumed to be a multiple
    of the corresponding tile size. All matrices are accessed as float16.
    """
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # iterate through the tiles in the current gemm problem
        while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
            # pick up a tile from the current gemm problem
            k = gk
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
            # figure out tile coordinates
            tile_idx_in_gemm = tile_idx - last_problem_end
            tile_m_idx = tile_idx_in_gemm // num_n_tiles
            tile_n_idx = tile_idx_in_gemm % num_n_tiles

            # do regular gemm here
            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                # hint to Triton compiler to do proper loop pipelining
                tl.multiple_of(a_ptrs, [16, 16])
                tl.multiple_of(b_ptrs, [16, 16])
                # assume full tile for now
                a = tl.load(a_ptrs)
                b = tl.load(b_ptrs)
                accumulator += tl.dot(a, b)
                a_ptrs += BLOCK_SIZE_K
                b_ptrs += BLOCK_SIZE_K * ldb
            c = accumulator.to(tl.float16)

            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]

            # assumes full tile for now
            tl.store(c_ptrs, c)

            # go to the next tile by advancing NUM_SM
            tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


def group_gemm_fn(group_A, group_B):
    """Multiply ``group_A[i] @ group_B[i]`` for every i with one grouped kernel launch.

    Args:
        group_A: list of 2-D tensors of shape (M_i, K_i).
        group_B: list of 2-D tensors of shape (K_i, N_i), same length as ``group_A``.

    Returns:
        A list of freshly allocated (M_i, N_i) output tensors on ``DEVICE``.
    """
    assert len(group_A) == len(group_B)
    group_size = len(group_A)

    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for A, B in zip(group_A, group_B):
        assert A.shape[1] == B.shape[0]
        M, K = A.shape
        K, N = B.shape
        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]

    # note these are device tensors
    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
    # we use a fixed number of CTA, and it's auto-tunable
    grid = lambda META: (META['NUM_SM'], )
    grouped_matmul_kernel[grid](
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
    )

    return group_C


# Cartesian product of tile shapes, pipeline stages and warp counts for the TMA kernel.
tma_configs = [
    triton.Config({'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, 'BLOCK_SIZE_K': BK}, num_stages=s, num_warps=w)
    for BM in [128]
    for BN in [128, 256]
    for BK in [64, 128]
    for s in [3, 4]
    for w in [4, 8]
]


@triton.autotune(
    tma_configs,
    # 'group_c_ptrs' was previously misspelled ('gropup_c_ptrs'); a key that
    # matches no kernel argument is ignored, so it never reached the cache key.
    key=['group_a_ptrs', 'group_b_ptrs', 'group_c_ptrs', 'group_size'],
)
@triton.jit
def grouped_matmul_tma_kernel(
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    # is the output FP8 or FP16
    FP8: tl.constexpr,
):
    """Grouped GEMM that loads and stores tiles through tensor descriptors.

    B is described with shape ``[gn, gk]`` and each loaded tile is transposed
    before the dot, so B is expected in (N, K) layout. Tiles are distributed
    over programs exactly as in ``grouped_matmul_kernel``: start at
    ``program_id(0)`` and stride by ``NUM_SM``.
    """
    dtype = tl.float8e4nv if FP8 else tl.float16
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # build the descriptors only if this program owns a tile of this problem
        if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
            # pick up a tile from the current gemm problem
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)

            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))

            a_desc = tl.make_tensor_descriptor(
                a_ptr,
                shape=[gm, gk],
                strides=[lda, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
            )

            b_desc = tl.make_tensor_descriptor(
                b_ptr,
                shape=[gn, gk],
                strides=[ldb, 1],
                block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
            )
            c_desc = tl.make_tensor_descriptor(
                c_ptr,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
            )

            # iterate through the tiles in the current gemm problem
            while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
                k = gk
                # figure out tile coordinates
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx = tile_idx_in_gemm // num_n_tiles
                tile_n_idx = tile_idx_in_gemm % num_n_tiles

                # do regular gemm here
                offs_am = tile_m_idx * BLOCK_SIZE_M
                offs_bn = tile_n_idx * BLOCK_SIZE_N

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                    a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])
                    b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])
                    accumulator += tl.dot(a, b.T)

                offs_cm = tile_m_idx * BLOCK_SIZE_M
                offs_cn = tile_n_idx * BLOCK_SIZE_N

                c = accumulator.to(dtype)
                c_desc.store([offs_cm, offs_cn], c)

                # go to the next tile by advancing NUM_SM
                tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


def group_gemm_tma_fn(group_A, group_B):
    """Grouped GEMM through the TMA kernel; every B must be given as (N, K).

    Args:
        group_A: list of 2-D tensors of shape (M_i, K_i).
        group_B: list of 2-D tensors of shape (N_i, K_i), same length as ``group_A``.

    Returns:
        A list of freshly allocated (M_i, N_i) output tensors on ``DEVICE``.
    """
    assert supports_tma()

    assert len(group_A) == len(group_B)
    group_size = len(group_A)

    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for A, B in zip(group_A, group_B):
        assert A.shape[1] == B.shape[1]
        M, K = A.shape
        N, K = B.shape
        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]
    # note these are device tensors
    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)

    # we use a fixed number of CTA, and it's auto-tunable

    # TMA descriptors require a global memory allocation
    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)

    grid = lambda META: (META['NUM_SM'], )
    grouped_matmul_tma_kernel[grid](d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size,
                                    FP8=torch.float8_e4m3fn == group_A[0].dtype, NUM_SM=num_sms())
    return group_C


group_m = [1024, 512, 256, 128]
group_n = [1024, 512, 256, 128]
group_k = [1024, 512, 256, 128]
group_A = []
group_B = []
group_B_T = []
assert len(group_m) == len(group_n)
assert len(group_n) == len(group_k)
group_size = len(group_m)

# Build one random fp16 problem per (M, N, K) triple. The TMA path consumes B
# in (N, K) layout, so a contiguous transpose of each B is kept alongside.
for M, N, K in zip(group_m, group_n, group_k):
    A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
    B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
    B_T = B.T.contiguous()
    group_A.append(A)
    group_B.append(B)
    group_B_T.append(B_T)

tri_out = group_gemm_fn(group_A, group_B)
ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)]

# Compare the plain grouped kernel against torch.matmul, problem by problem.
for expected, actual in zip(ref_out, tri_out):
    assert torch.allclose(expected, actual, atol=1e-2, rtol=1e-2)

# Compare the TMA kernel as well when the device supports it.
if supports_tma():
    tri_tma_out = group_gemm_tma_fn(group_A, group_B_T)
    for expected, actual in zip(ref_out, tri_tma_out):
        assert torch.allclose(expected, actual, atol=1e-2, rtol=1e-2)


# only launch the kernel, no tensor preparation here to remove all overhead
def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size):
    """Launch ``grouped_matmul_kernel`` on already-prepared device tensors."""
    grid = lambda META: (META['NUM_SM'], )
    grouped_matmul_kernel[grid](
        a_ptrs,
        b_ptrs,
        c_ptrs,
        sizes,
        lds,
        group_size,
    )


def triton_tma_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, dtype):
    """Launch ``grouped_matmul_tma_kernel`` on already-prepared device tensors."""
    grid = lambda META: (META['NUM_SM'], )
    grouped_matmul_tma_kernel[grid](a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, FP8=torch.float8_e4m3fn == dtype,
                                    NUM_SM=num_sms())


def torch_perf_fn(group_A, group_B):
    """Reference path: one ``torch.matmul`` per problem, results discarded."""
    for a, b in zip(group_A, group_B):
        torch.matmul(a, b)


def _bench_group_gemm(M, N, K, provider, group_size=4):
    """Time one provider on ``group_size`` random fp16 (M, K) x (K, N) problems.

    Returns ``(median_ms, q20_ms, q80_ms)``, i.e. (mean, min, max) in the order
    that ``triton.testing.perf_report`` unpacks.
    """
    if provider not in ('cublas', 'triton', 'triton-tma'):
        raise ValueError(f"Unknown provider: {provider!r}")

    # The kernels only receive raw addresses, so these lists are what keeps the
    # underlying tensors referenced while the benchmark runs.
    group_A, group_B, group_B_T, group_C = [], [], [], []
    A_addrs, B_addrs, B_T_addrs, C_addrs = [], [], [], []
    g_sizes, g_lds, g_T_lds = [], [], []
    for _ in range(group_size):
        A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
        B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
        C = torch.empty((M, N), device=DEVICE, dtype=torch.float16)
        B_T = B.T.contiguous()
        group_A.append(A)
        group_B.append(B)
        group_B_T.append(B_T)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        B_T_addrs.append(B_T.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]
        # the TMA kernel reads B in (N, K) layout, hence B_T's leading dimension
        g_T_lds += [A.stride(0), B_T.stride(0), C.stride(0)]

    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
    d_g_t_lds = torch.tensor(g_T_lds, dtype=torch.int32, device=DEVICE)

    quantiles = [0.5, 0.2, 0.8]
    if provider == 'cublas':
        fn = lambda: torch_perf_fn(group_A, group_B)
    elif provider == 'triton':
        fn = lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size)
    else:  # 'triton-tma'
        fn = lambda: triton_tma_perf_fn(d_a_ptrs, d_b_t_ptrs, d_c_ptrs, d_g_sizes, d_g_t_lds, group_size,
                                        dtype=torch.float16)
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
    # Raw runtimes: the 0.2 quantile is the lower bound and 0.8 the upper bound.
    return ms, min_ms, max_ms


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['N'],
        x_vals=[2**i for i in range(7, 11)],  # different possible values for `x_name`
        line_arg='provider',
        # argument name whose value corresponds to a different line in the plot
        # possible values for `line_arg`
        line_vals=['cublas', 'triton'] + (['triton-tma'] if supports_tma() else []),
        # label name for the lines
        line_names=["cuBLAS", "Triton"] + (['Triton + TMA'] if supports_tma() else []),
        # line styles
        styles=[('green', '-'), ('blue', '-')] + ([('red', '-')] if supports_tma() else []),
        ylabel="runtime(ms)",  # label name for the y-axis
        plot_name="group-gemm-performance",
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    ))
def benchmark_square_matrices(N, provider):
    """Benchmark a group of four square N x N x N GEMMs for ``provider``."""
    return _bench_group_gemm(N, N, N, provider)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['M'],
        x_vals=[2**i for i in range(7, 11)],  # different possible values for `x_name`
        line_arg='provider',
        # argument name whose value corresponds to a different line in the plot
        # possible values for `line_arg`
        line_vals=['cublas', 'triton'] + (['triton-tma'] if supports_tma() else []),
        # label name for the lines
        line_names=["cuBLAS", "Triton"] + (['Triton + TMA'] if supports_tma() else []),
        # line styles
        styles=[('green', '-'), ('blue', '-')] + ([('red', '-')] if supports_tma() else []),
        ylabel="runtime(ms)",  # label name for the y-axis
        # NOTE(review): M is the swept axis while N and K are fixed at 8192, so
        # "n-8192-k-8192" would describe the plot better; the name is left
        # unchanged because it is also the output file name.
        plot_name="group-gemm-performance-m-8192-k-8192",
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    ))
def benchmark_batches(M, provider):
    """Benchmark a group of four M x 8192 x 8192 GEMMs for ``provider``."""
    N = 8192
    K = 8192
    return _bench_group_gemm(M, N, K, provider)


benchmark_square_matrices.run(show_plots=True, print_data=True)
benchmark_batches.run(show_plots=True, print_data=True)
1. 导入必要的模块和工具函数
from typing import Optional
import torchimport triton
import triton.language as tl
导入了 `torch`(PyTorch 框架)、`triton`(用于编写和优化 GPU 内核)及其语言模块 `tl`。
2. 检查是否使用 CUDA 后端及设备特性
def is_cuda():
    """Report whether Triton's active driver targets the CUDA backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda"


def supports_tma():
    """Report whether this is a CUDA device with compute-capability major >= 9."""
    if not is_cuda():
        return False
    major = torch.cuda.get_device_capability()[0]
    return major >= 9


def num_sms():
    """Return the CUDA device's multiprocessor count; fall back to 148 off CUDA."""
    if not is_cuda():
        return 148
    return torch.cuda.get_device_properties("cuda").multi_processor_count
定义了辅助函数来判断是否使用 CUDA 后端、是否支持 TMA(Tensor Memory Accelerator,张量内存加速器)以及获取设备的 SM(Streaming Multiprocessor)数量。
3. 定义分组矩阵乘法内核
@triton.autotune(configs=[# ... 配置列表 ...],key=['group_size'],
)
@triton.jit
def grouped_matmul_kernel(# ... 参数列表 ...
):# 内核实现逻辑
利用 triton.autotune
和 triton.jit
装饰器定义了一个自动调优的内核函数,用于执行分组矩阵乘法操作。
具体实现逻辑:
- 遍历每个分组,计算每个矩阵对的大小(M、N、K)和内存布局(leading dimensions)。
- 将矩阵切片为 `BLOCK_SIZE_M × BLOCK_SIZE_N` 大小的块(tile)。
- 通过循环迭代每个块,从中加载数据并执行矩阵乘法计算。
- 通过 `tl.load` 和 `tl.store` 操作与设备内存交互。
4. 定义分组矩阵乘法的上层函数
def group_gemm_fn(group_A, group_B):# ... 函数实现 ...return group_C
用于调用上述内核函数执行分组矩阵乘法的上层函数。
具体实现逻辑:
- 验证输入矩阵对数量一致。
- 遍历每个矩阵对,准备分组信息(矩阵大小、内存布局等)。
- 创建输出矩阵并收集所有设备指针。
- 转换为设备张量并调用内核函数。
5. 实现基于 TMA 的分组矩阵乘法版本(若支持 TMA)
@triton.autotune(tma_configs,key=['group_size'],
)
@triton.jit
def grouped_matmul_tma_kernel(# ... 参数列表 ...
):# TMA 版本的内核实现逻辑
利用 TMA 技术实现的高性能分组矩阵乘法内核,并通过 triton.set_allocator
设置了专门的内存分配函数以支持 TMA。
6. 测试代码和性能基准
group_m = [1024, 512, 256, 128]
group_n = [1024, 512, 256, 128]
group_k = [1024, 512, 256, 128]
# ... 测试数据准备 ...tri_out = group_gemm_fn(group_A, group_B)
ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)]
# ... 结果验证 ...# ... 性能基准测试代码 ...
- 准备了测试数据并调用上述函数进行验证。
- 使用 `assert torch.allclose(...)` 验证结果一致性。
- 最后,使用 `triton.testing.perf_report` 和 `triton.testing.Benchmark` 定义性能测试函数并运行,生成性能报告。
Persistent Matmul
这个就不贴代码了,我刚看到这块内容的时候也有一些好奇,主要阐述一下和普通矩阵乘法的区别:
这两段代码分别实现了非持久化矩阵乘法和持久化矩阵乘法,它们有以下区别:
- 任务分配方式不同:
  - 非持久化矩阵乘法(代码1):每个程序只处理一个子矩阵的乘法任务。
  - 持久化矩阵乘法(代码2):通过循环和多个程序 ID(`start_pid`)的使用,实现更细粒度的任务分配。一个程序可以处理多个子矩阵的乘法任务,从而提高资源利用率。
- 资源利用率不同:
- 非持久化矩阵乘法:由于每个程序处理一个子矩阵,可能存在计算资源未充分利用的情况,特别是在矩阵规模较大时。
- 持久化矩阵乘法:通过循环多次调度程序(使用
tl.range
),可以更有效地利用 GPU 的计算资源,尤其是在大规模矩阵运算中。
- 数据流和存储操作的优化程度不同:
- 非持久化矩阵乘法:在计算完成后,直接存储结果到目标位置。这种设计简单直接,但可能在大规模运算中导致存储操作的不连续性。
- 持久化矩阵乘法:通过延迟存储操作(
tile_id_c += NUM_SMS
后才进行存储),允许计算和存储操作的重叠,从而提高整体效率。
- 硬件调度和性能优化的不同:
- 非持久化矩阵乘法:主要依赖于 GPU 的自动调度机制,对大规模矩阵的适应性可能较差。
- 持久化矩阵乘法:通过显式管理程序调度和计算任务,减少了因硬件调度机制导致的延迟,特别是在大规模矩阵运算中能显著提高性能。
方面 | 非持久化矩阵乘法 | 持久化矩阵乘法 |
---|---|---|
任务分配方式 | 一次性分配,线程块处理单个任务 | 循环分配,线程块处理多个任务 |
线程块调度 | 简单直接,线程块独立处理任务 | 复杂,线程块持续从任务队列获取任务 |
资源利用率 | 较低,处理多个小任务时易出现空闲 | 较高,充分利用 GPU 资源 |
适用场景 | 单个大矩阵乘法任务 | 多个小矩阵乘法任务 |
实现复杂度 | 较低,逻辑简单 | 较高,需要管理任务队列和调度 |
综上所述,持久化矩阵乘法比非持久化矩阵乘法在大规模矩阵运算中更高效,因为它通过更细致的调度和资源管理,充分利用了 GPU 的计算资源,降低了存储操作的延迟。
Block Scaled Matrix Multiplication
CUDA 设备若支持 PTX 8.7 及更高版本,便能利用块缩放矩阵乘法指令。为确保在张量核心矩阵乘法的快速内循环中低延迟访问这些缩放因子,须保证块缩放因子在内存中以连续布局存储,与访问模式相符。
块缩放矩阵乘法的张量核心指令会计算如下乘积:
$$C = (A \times scale\_a) @ (B \times scale\_b)$$
其中,$scale\_a$ 和 $scale\_b$ 分别是矩阵 A 和 B 的块缩放因子。在块缩放矩阵乘法下,每个缩放因子会沿着各自的 K 轴广播并乘以矩阵 A 和 B 的元素向量。此处,A 和 B 中每个缩放因子广播的元素数量被称为向量大小(VEC_SIZE)。
在行主序的线性布局中,缩放因子的形状为:
$$(M,\ K // VEC\_SIZE) \text{ 和 } (N,\ K // VEC\_SIZE)$$
不过,为避免非连续内存访问,将缩放因子存储为打包的块布局更为有利。对于左侧矩阵(LHS),布局如下:
$$\left( \frac{M}{32 \times 4},\ \frac{K}{VEC\_SIZE \times 4},\ 32,\ 4,\ 4 \right)$$
如此一来,在 K 块的快速内循环中,每个张量核心 MMA 可连续访问 M 轴上 128 行的缩放因子块,对应矩阵 A 的每个 BLOCK_M x BLOCK_K 子块。
为符合 Triton 语言对 dot_scaled 的语义要求,缩放因子需按上述 5D 布局准备,但随后需逻辑转置并重塑为张量点积期望的 2D 布局。
import argparse

import torch
import triton
import triton.language as tl
import triton.profiler as proton
from triton.tools.tensor_descriptor import TensorDescriptor
from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor


def is_cuda():
    """Return True when Triton's active driver targets the CUDA backend."""
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_block_scaling():
    """Return True on a CUDA device whose compute-capability major version is exactly 10."""
    return is_cuda() and torch.cuda.get_device_capability()[0] == 10


def _matmul_launch_metadata(grid, kernel, args):
    """Build the profiler metadata (display name and FLOP count) for one launch.

    The kernel name gets a format suffix derived from the element packing of A
    and B and, for the fp4/fp4 case, from the scale vector size.
    """
    ret = {}
    M, N, K = args["M"], args["N"], args["K"]
    kernel_name = kernel.name
    # Every key must be present. The previous form,
    # `"ELEM_PER_BYTE_A" and "ELEM_PER_BYTE_B" and "VEC_SIZE" in args`,
    # only tested "VEC_SIZE" because non-empty string literals are truthy.
    if all(key in args for key in ("ELEM_PER_BYTE_A", "ELEM_PER_BYTE_B", "VEC_SIZE")):
        if args["ELEM_PER_BYTE_A"] == 1 and args["ELEM_PER_BYTE_B"] == 1:
            kernel_name += "_mxfp8"
        elif args["ELEM_PER_BYTE_A"] == 1 and args["ELEM_PER_BYTE_B"] == 2:
            kernel_name += "_mixed"
        elif args["ELEM_PER_BYTE_A"] == 2 and args["ELEM_PER_BYTE_B"] == 2:
            if args["VEC_SIZE"] == 16:
                kernel_name += "_nvfp4"
            elif args["VEC_SIZE"] == 32:
                kernel_name += "_mxfp4"
    ret["name"] = f"{kernel_name} [M={M}, N={N}, K={K}]"
    ret["flops"] = 2. * M * N * K
    return ret


@triton.jit(launch_metadata=_matmul_launch_metadata)
def block_scaled_matmul_kernel(
        a_desc, a_scale,
        b_desc, b_scale,
        c_desc,
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        stride_sk: tl.constexpr, stride_sb: tl.constexpr, stride_sc: tl.constexpr, stride_sd: tl.constexpr,
        output_type: tl.constexpr,
        ELEM_PER_BYTE_A: tl.constexpr,
        ELEM_PER_BYTE_B: tl.constexpr,
        VEC_SIZE: tl.constexpr,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        BLOCK_K: tl.constexpr,
        NUM_STAGES: tl.constexpr,
        USE_2D_SCALE_LOAD: tl.constexpr):
    """Compute one (BLOCK_M, BLOCK_N) output tile with ``tl.dot_scaled``.

    A and B tiles come from tensor descriptors; B is loaded as (BLOCK_N, K-chunk)
    and transposed for the dot. Scale factors are read from the packed 5-D layout
    (either via a flattened 2-D load or a 5-D pointer grid) and rearranged into
    the 2-D (rows, BLOCK_K // VEC_SIZE) shape passed to ``tl.dot_scaled``.
    ``output_type`` selects the store dtype: 0 -> float32, 1 -> float16, 2 -> float8e4nv.
    """
    if output_type == 0:
        output_dtype = tl.float32
    elif output_type == 1:
        output_dtype = tl.float16
    elif output_type == 2:
        output_dtype = tl.float8e4nv

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    offs_am = pid_m * BLOCK_M
    offs_bn = pid_n * BLOCK_N
    offs_k_a = 0
    offs_k_b = 0

    ## block scale offsets
    offs_sm = (pid_m * (BLOCK_M // 128) + tl.arange(0, BLOCK_M // 128)) % M
    offs_sn = (pid_n * (BLOCK_N // 128) + tl.arange(0, BLOCK_N // 128)) % N

    MIXED_PREC: tl.constexpr = ELEM_PER_BYTE_A == 1 and ELEM_PER_BYTE_B == 2

    # For now it is recommended to use 2D scale loads for better performance.
    # In the future we will bring additional optimizations to either allow 5D loads,
    # the use of TMAs for scale factors, or both.
    if USE_2D_SCALE_LOAD:
        offs_inner = tl.arange(0, (BLOCK_K // VEC_SIZE // 4) * 32 * 4 * 4)
        a_scale_ptr = a_scale + offs_sm[:, None] * stride_sk + offs_inner[None, :]
        b_scale_ptr = b_scale + offs_sn[:, None] * stride_sk + offs_inner[None, :]
    else:
        offs_sk = tl.arange(0, (BLOCK_K // VEC_SIZE // 4))
        # MN spatial offsets for 32 element blocking
        offs_sc = tl.arange(0, 32)
        # offsets for both scale factor column ID (along K)
        # and spatial block column ID (along MN)
        offs_sd = tl.arange(0, 4)
        a_scale_ptr = a_scale + (offs_sm[:, None, None, None, None] * stride_sk + offs_sk[None, :, None, None, None] *
                                 stride_sb + offs_sc[None, None, :, None, None] * stride_sc +
                                 offs_sd[None, None, None, :, None] * stride_sd + offs_sd[None, None, None, None, :])
        b_scale_ptr = b_scale + (offs_sn[:, None, None, None, None] * stride_sk + offs_sk[None, :, None, None, None] *
                                 stride_sb + offs_sc[None, None, :, None, None] * stride_sc +
                                 offs_sd[None, None, None, :, None] * stride_sd + offs_sd[None, None, None, None, :])

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
        a = a_desc.load([offs_am, offs_k_a])
        b = b_desc.load([offs_bn, offs_k_b])
        scale_a = tl.load(a_scale_ptr)
        scale_b = tl.load(b_scale_ptr)
        if USE_2D_SCALE_LOAD:
            scale_a = scale_a.reshape(BLOCK_M // 128, BLOCK_K // VEC_SIZE // 4, 32, 4, 4)
            scale_b = scale_b.reshape(BLOCK_N // 128, BLOCK_K // VEC_SIZE // 4, 32, 4, 4)
        scale_a = scale_a.trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // VEC_SIZE)
        scale_b = scale_b.trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K // VEC_SIZE)

        if MIXED_PREC:
            accumulator = tl.dot_scaled(a, scale_a, "e4m3", b.T, scale_b, "e2m1", accumulator)
        elif ELEM_PER_BYTE_A == 2 and ELEM_PER_BYTE_B == 2:
            accumulator = tl.dot_scaled(a, scale_a, "e2m1", b.T, scale_b, "e2m1", accumulator)
        else:
            accumulator = tl.dot_scaled(a, scale_a, "e4m3", b.T, scale_b, "e4m3", accumulator)

        # K offsets are in stored elements, hence the division by elements-per-byte
        offs_k_a += BLOCK_K // ELEM_PER_BYTE_A
        offs_k_b += BLOCK_K // ELEM_PER_BYTE_B
        a_scale_ptr += (BLOCK_K // VEC_SIZE // 4) * stride_sb
        b_scale_ptr += (BLOCK_K // VEC_SIZE // 4) * stride_sb

    c_desc.store([offs_am, offs_bn], accumulator.to(output_dtype))


def block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, dtype_dst, M, N, K, configs):
    """Allocate the (M, N) output and launch ``block_scaled_matmul_kernel``.

    Args:
        a_desc, b_desc: tensor descriptors for the A and B operands.
        a_scale, b_scale: scale-factor tensors; the strides of ``a_scale`` are
            passed to the kernel and used for both scale tensors.
        dtype_dst: ``torch.float32``, ``torch.float16`` or ``torch.float8_e4m3fn``.
        M, N, K: problem sizes.
        configs: dict with tile sizes, ``num_stages``, ``ELEM_PER_BYTE_A/B`` and ``VEC_SIZE``.

    Returns:
        The output tensor, allocated on "cuda".

    Raises:
        ValueError: if ``dtype_dst`` is not one of the supported dtypes.
    """
    output = torch.empty((M, N), dtype=dtype_dst, device="cuda")
    # the kernel receives the output dtype as an integer code
    if dtype_dst == torch.float32:
        dtype_dst = 0
    elif dtype_dst == torch.float16:
        dtype_dst = 1
    elif dtype_dst == torch.float8_e4m3fn:
        dtype_dst = 2
    else:
        raise ValueError(f"Unsupported dtype: {dtype_dst}")

    BLOCK_M = configs["BLOCK_SIZE_M"]
    BLOCK_N = configs["BLOCK_SIZE_N"]
    c_desc = TensorDescriptor.from_tensor(output, [BLOCK_M, BLOCK_N])

    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
    block_scaled_matmul_kernel[grid](a_desc, a_scale, b_desc, b_scale, c_desc, M, N, K, a_scale.stride(0),
                                     a_scale.stride(1), a_scale.stride(2), a_scale.stride(3), dtype_dst,
                                     configs["ELEM_PER_BYTE_A"], configs["ELEM_PER_BYTE_B"], configs["VEC_SIZE"],
                                     configs["BLOCK_SIZE_M"], configs["BLOCK_SIZE_N"], configs["BLOCK_SIZE_K"],
                                     configs["num_stages"], USE_2D_SCALE_LOAD=True)
    return output


def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference=False):
    """Create random operands, packed scale factors and launch configs for one format.

    Args:
        M, N, K: problem sizes.
        block_scale_type: one of "nvfp4", "mxfp4", "mxfp8", "mixed".
        compute_reference: when True, also compute a float32 reference product.

    Returns:
        ``(a_desc, a_scale, b_desc, b_scale, configs, reference)``; ``reference``
        is None unless ``compute_reference`` is True.
    """
    BLOCK_M = 128
    BLOCK_N = 256
    BLOCK_K = 256 if "fp4" in block_scale_type else 128
    VEC_SIZE = 16 if block_scale_type == "nvfp4" else 32
    assert block_scale_type in ["nvfp4", "mxfp4", "mxfp8", "mixed"], f"Invalid block scale type: {block_scale_type}"
    ELEM_PER_BYTE_A = 2 if "fp4" in block_scale_type else 1
    ELEM_PER_BYTE_B = 1 if block_scale_type == "mxfp8" else 2

    device = "cuda"
    a_ref = MXFP4Tensor(size=(M, K), device=device).random()
    # Similar to Hopper's wgmma symmetric fp8 instruction, the RHS is expected
    # to be in col-major layout for Blackwell's tcgen05.mma when using fp4 operands.
    # To conform to the expected semantics of tl.dot_scaled, (M, K) x (K, N),
    # the data is generated in col-major layout, packed along K for fp4, and then
    # logically transposed. Note that if one operand is of fp8 precision, unlike Hopper,
    # Blackwell supports both row-major and col-major layouts for the RHS matrix.
    # For the mixed-precision case, the fp4 RHS can be either in row or col-major layout.
    # But for performance reason, it is recommended to use col-major layout. If TMA is used
    # for the fp4 RHS operand load in mixed-precision dot, as in this tutorial, it must be
    # in col-major layout.
    b_ref = MXFP4Tensor(size=(N, K), device=device).random()
    if block_scale_type in ["mxfp8", "mixed"]:
        a_ref = a_ref.to(torch.float32)
        a = a_ref.to(torch.float8_e4m3fn)
    else:
        # Pack two fp4 elements per byte along K
        a = a_ref.to_packed_tensor(dim=1)

    if block_scale_type == "mxfp8":
        b_ref = b_ref.to(torch.float32)
        b = b_ref.to(torch.float8_e4m3fn)
    else:
        b = b_ref.to_packed_tensor(dim=1)

    # the reference product needs B as (K, N)
    b_ref = b_ref.to(torch.float32).T

    a_desc = TensorDescriptor.from_tensor(a, [BLOCK_M, BLOCK_K // ELEM_PER_BYTE_A])
    if block_scale_type == "mixed":
        b_desc = TensorDescriptor(
            b,
            shape=[N, K // ELEM_PER_BYTE_B],
            strides=[K // ELEM_PER_BYTE_B, 1],
            block_shape=[BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B],
        )
    else:
        b_desc = TensorDescriptor.from_tensor(b, [BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B])

    epsilon = 1e-8
    a_scale = torch.rand((M // 128, K // VEC_SIZE // 4, 32, 4, 4), device=device) + epsilon
    b_scale = torch.rand((N // 128, K // VEC_SIZE // 4, 32, 4, 4), device=device) + epsilon
    if block_scale_type == "nvfp4":
        a_scale = a_scale.to(torch.float8_e4m3fn)
        b_scale = b_scale.to(torch.float8_e4m3fn)
        a_scale_ref = a_scale
        b_scale_ref = b_scale
    elif block_scale_type in ["mxfp4", "mxfp8", "mixed"]:
        a_scale_ref = MXScaleTensor(a_scale)
        b_scale_ref = MXScaleTensor(b_scale)
        a_scale = a_scale_ref.data
        b_scale = b_scale_ref.data

    reference = None
    if compute_reference:
        a_scale_ref = a_scale_ref.to(torch.float32)
        b_scale_ref = b_scale_ref.to(torch.float32)

        def unpack_scale(packed):
            # undo the packed 5-D layout: (chunks_m, chunks_k, 32, 4, 4) -> (chunks_m * 128, chunks_k * 4)
            num_chunk_m, num_chunk_k, _, _, _ = packed.shape
            return packed.permute(0, 3, 2, 1, 4).reshape(num_chunk_m * 128, num_chunk_k * 4).contiguous()

        a_scale_ref = unpack_scale(a_scale_ref).repeat_interleave(VEC_SIZE, dim=1)[:M, :K]
        b_scale_ref = unpack_scale(b_scale_ref).repeat_interleave(VEC_SIZE, dim=1).T.contiguous()[:K, :N]
        reference = torch.matmul(a_ref.to(torch.float32) * a_scale_ref, b_ref * b_scale_ref)

    configs = {
        "BLOCK_SIZE_M": BLOCK_M,
        "BLOCK_SIZE_N": BLOCK_N,
        "BLOCK_SIZE_K": BLOCK_K,
        "num_stages": 4,
        "ELEM_PER_BYTE_A": ELEM_PER_BYTE_A,
        "ELEM_PER_BYTE_B": ELEM_PER_BYTE_B,
        "VEC_SIZE": VEC_SIZE,
    }
    return a_desc, a_scale, b_desc, b_scale, configs, reference


def validate_block_scaled(M, N, K, block_scale_type="nvfp4"):
    """Run the kernel once and compare it with the float32 reference (atol = rtol = 1e-3)."""

    def alloc_fn(size: int, align: int, _):
        return torch.empty(size, dtype=torch.int8, device="cuda")

    if block_scale_type == "mixed":
        # This is needed for TMA with the descriptor created on the device.
        # TMA load for mixed-precision fp4 is supported only by device TMA.
        triton.set_allocator(alloc_fn)

    a_desc, a_scale, b_desc, b_scale, configs, reference = initialize_block_scaled(M, N, K, block_scale_type,
                                                                                   compute_reference=True)
    output = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)
    torch.testing.assert_close(reference, output.to(torch.float32), atol=1e-3, rtol=1e-3)
    print(f"✅ (pass {block_scale_type})")


def bench_block_scaled(K, block_scale_type="nvfp4", reps=10):
    """Profile ``reps`` launches of an 8192 x 8192 x K problem under proton."""
    assert K % 128 == 0
    M = 8192
    N = 8192
    print(f"Problem Shape = {M}x{N}x{K}")

    a_desc, a_scale, b_desc, b_scale, configs, _ = initialize_block_scaled(M, N, K, block_scale_type,
                                                                           compute_reference=False)
    # one launch before the profiler is activated, so it is not recorded
    _ = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)

    proton.activate(0)
    for _ in range(reps):
        _ = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)
    proton.deactivate(0)
    print("Done benchmarking")


def show_profile(profile_name):
    """Print the ``tflop/s`` and ``time/ms`` metrics from ``<profile_name>.hatchet``."""
    import triton.profiler.viewer as proton_viewer

    metric_names = ["time/ms"]
    metric_names = ["tflop/s"] + metric_names
    file_name = f"{profile_name}.hatchet"
    tree, metrics = proton_viewer.parse(metric_names, file_name)
    proton_viewer.print_tree(tree, metrics)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-K", type=int, required=False, default=512)
    parser.add_argument("--K_range", type=int, nargs=2)
    parser.add_argument("--K_step", type=int, default=512)
    # NOTE(review): store_true with default=True means this flag can never be
    # switched off from the command line; kept as-is to preserve the CLI.
    parser.add_argument("--bench", action="store_true", default=True)
    parser.add_argument("--format", type=str, choices=["mxfp4", "nvfp4", "mxfp8", "mixed"], default="nvfp4")
    args = parser.parse_args()

    if not supports_block_scaling():
        print("⛔ This example requires GPU support for block scaled matmul")
    else:
        if args.K and args.K_range is None:
            args.K_range = [args.K, args.K]
            args.K_step = 1  # doesn't matter as long as it's not 0

        torch.manual_seed(42)

        validate_block_scaled(8192, 8192, 8192, block_scale_type=args.format)

        if args.bench:
            proton.start("block_scaled_matmul", hook="triton")
            proton.deactivate(0)  # Skip argument creation
            for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
                bench_block_scaled(K, reps=10000, block_scale_type=args.format)
            proton.finalize()
            show_profile("block_scaled_matmul")
总结
-
通过了一些练习学习了triton的基础语法和一些gpu的知识,还有些不懂的继续学习之后再回忆理解一下
-
会继续跟进这块方向的知识,构建起完整的知识树
Reference
- 从啥也不会到CUDA GEMM优化
- Tutorials — Triton documentation