当前位置: 首页 > news >正文

PTX指令集基础以及warp级矩阵乘累加指令介绍

文章目录

  • PTX指令集基础
      • 基本指令格式
      • 寄存器
      • 举例
      • 指令和操作数的类型
    • 16bit的机器特定语义
  • warp级矩阵乘累计指令
    • WMMA
      • Integer Fragment
        • 乘法矩阵
        • 累加器矩阵
      • 低秩对齐
      • wmma.load
      • wmma.store
      • wmma.mma
    • MMA
      • 语法
        • 半精度数据类型
        • 带block_scaling的半精度数据类型
      • 举例
      • ldmatrix
        • 语法
        • 在cuda kernel中使用
      • stmatrix
        • 语法
        • 在cuda kernel中使用

本文主要基于cuda官方文档:URL

PTX指令集基础

基本指令格式

PTX 指令:opcode是具体操作指令, 以及后面跟随0 到 4 个操作数,此外,在操作指令左侧可通过 @ 符号添加一个可选谓词,具体形式如下:

  • @p opcode;
  • @p opcode a;
  • @p opcode d, a;
  • @p opcode d, a, b;
  • @p opcode d, a, b, c;

说明:

  • @p用来控制指令的指定条件,@!p表示取反。

  • d 表示目标操作数

  • abc 表示源操作数

特殊情况:

  1. 下面指令会写入两个目标寄存器,用 | 符号分隔多个目标寄存器。
    示例:setp.lt.s32 p|q, a, b; // p = (a < b); q = !(a < b);

寄存器

谓词寄存器

谓词寄存器专门用于存储布尔值(真 / 假),常用于条件执行。谓词寄存器的声明格式如下:

.reg .pred p, q, r;

  • .pred是关键字,表示寄存器类型是谓词寄存器

通用寄存器

用于存储多种类型的数据(如整数、浮点数地址等),支持各种算术和逻辑运算。声明格式如下:

.reg .s32 a, b;

其声明了两个 32 位整数寄存器ab

举例

对于高级代码:

if (i < n)j = j + 1;

其相应的PTX代码如下:

      setp.lt.s32  p, i, n;    // p = (i < n),执行有符号 32 位整数的小于比较(Less Than)并设置相应的谓词寄存器
@p    add.s32      j, j, 1;    // if i < n, add 1 to j

条件分支PTX代码:

      setp.lt.s32  p, i, n;    // p = (i < n),
@!p   bra  L1;                 // if p 为 False, 跳转(bra)到标签L1对应的指令add.s32      j, j, 1;    // if p 为 True,执行该条指令
L1:     ...

指令和操作数的类型

指令必须包含数据类型修饰符。示例:

.reg .u16 d, a, b;   // 寄存器声明add.u16 d, a, b;    // 执行16位无符号加法

有些指令需要多个数据类型修饰符,最典型的是数据转换指令cvt。它需要为结果和源分别指定数据类型修饰符,且修饰符的顺序与操作数的顺序一致。例如:

.reg .u16 a;
.reg .f32 d;cvt.f32.u16 d, a;   // 将16位无符号整数a转换为32位浮点数d

通常情况下,操作数的数据类型必须与对应的指令类型修饰符一致。操作数与指令类型的一致性规则如下:

  • 位大小类型与相同大小的任何类型一致。
  • 有符号和无符号整数类型在大小相同时一致,且必要时整数操作数会被隐式转换为指令类型。例如,在有符号整数指令中使用无符号整数操作数时,该操作数会被指令当作有符号整数处理。
  • 浮点类型仅在大小相同时才一致,即它们必须完全匹配。

当指令类型和操作数类型不匹配时:

  • ld、st 和 cvt 指令允许源操作数和目的操作数的数据宽度大于指令类型的大小
  • 位大小和整数(有符号与无符号)指令类型的操作数类型检查规则有所放宽;浮点指令类型仍要求操作数类型大小完全匹配,除非操作数是位大小类型
  • 当源操作数大小超过指令类型大小时,源数据会被截断至指令类型大小所指定的相应位数

16bit的机器特定语义

PTX程序在支持不同数据类型的GPU上的执行情况:

  • PTX程序可在16位或32位数据类型的GPU上运行。在32位数据类型上,PTX的16位寄存器会映射到32位物理寄存器,16位计算也会提升为32位计算,这可能导致同一代码在16位和32位机器上的计算结果出现差异,额外的精度位可能在应用层面显现(如右移指令)。

  • 从PTX语言层面,一种解决方案是为16位代码定义与16位数据类型执行一致的语义,但这会使16位代码在32位数据类型上执行时产生性能损失(需额外屏蔽指令处理高位多余精度位)。

  • 因此,PTX中16位指令的语义是机器特定的,以避免32位GPU上16位代码的性能损失。编译器或程序员可通过在程序适当位置添加显式16位转换来保证代码可移植性,但这对许多性能关键型应用并不理想,且很多应用更能接受执行差异而非限制性能。

warp级矩阵乘累计指令

特性wmmamma
内存布局管理隐式处理:使用wmma.loadwmma.store指令时,会隐式地将矩阵元素从内存组织为适合 GPU 计算的格式,开发者无需手动管理线程间的数据分配,简化了编程模型。显式处理:需要显式处理矩阵元素在 warp 内各线程间的分布。开发者必须手动将矩阵分块并分配到不同线程,控制数据的加载和存储方式。
稀疏矩阵支持仅密集矩阵密集矩阵和结构化稀疏矩阵
编程复杂度
指令集层次高级抽象底层指令
适用场景快速开发、密集矩阵高性能计算、稀疏矩阵

选择建议

  • 若追求开发效率且处理密集矩阵,优先使用wmma
  • 若需要处理稀疏矩阵或追求极致性能优化,选择mma
  • 在Hopper上,为了获得最高的性能,应该使用wgmma指令。

WMMA

  • 每个线程持有一个片段:Warp 中的 32 个线程,每个线程都会 “拿” 到矩阵的一小部分(即 “片段”),32 个线程的片段合起来才能构成完整的矩阵(或矩阵的一个分块)。
  • 片段分布 “不透明”:片段在矩阵中的具体位置(比如哪个线程拿了矩阵的哪几行哪几列)是由 GPU 硬件架构(如 SM 7.0、8.0 等)决定的,用户不需要关心具体分布,只需按规则使用即可(这也是 “分布未指定” 的含义)。
  • 片段可以在多个 wmma 操作中传递使用,但有严格的 “兼容性要求”:
    • 必须匹配的条件:如果一个片段要作为另一个 wmma 操作的输入,两个操作涉及的矩阵必须满足:形状、布局、元素类型
    • 如果两个函数编译时针对的 GPU 架构(SM 版本)不同,即使链接在一起,它们的片段也可能不兼容(因为不同架构对片段的内部布局设计可能不同)。
    • 不要把片段传给带有.weak链接属性的函数:这种函数在链接时可能被替换成其他模块的版本,而不同模块的片段布局可能不匹配,导致错误。

Integer Fragment

乘法矩阵
形状(m x n x k)矩阵A的片段构成矩阵B的片段构成背后逻辑
.m16n16k162个.b32寄存器,每个存4个元素2个.b32寄存器,每个存4个元素- .b32是32位寄存器,每个能存4个8位元素(4×8=32);
- A和B的每个线程各拿2×4=8个元素,32个线程共拿32×8=256个元素,正好匹配16×16矩阵的大小(16×16=256)。
.m8n32k161个.b32寄存器,存4个元素4个.b32寄存器,每个存4个元素- A是8×16矩阵,每个线程拿4个元素,32线程共拿32×4=128个元素(8×16=128);
- B是16×32矩阵,每个线程拿4×4=16个元素,32线程共拿32×16=512个元素(16×32=512)。
累加器矩阵

矩阵乘法的结果(或中间累加结果)存在累加器中,元素类型是.s32(32位有符号整数,位数更高是为了避免计算时溢出)。

数据类型形状(m x n x k)片段构成解读
.s32.m16n16k168个.s32寄存器16x16的结果矩阵,每个线程负责的片段需要用8个32位寄存器存储(32位寄存器直接存32位整数,无需像8位那样“打包”)。
.s32.m8n32k168个.s32寄存器每个线程负责的片段需要用8个32位寄存器存储

低秩对齐

考虑下面的指令:

wmma.load.a.sync.aligned.row.m16n16k16.f16 {x0,...,x7}, [p], s;
  • 段大小是32字节,每段可以存储8个.fp16×2数据(两个.fp16的向量)
  • 指令中的s是 “步长”,但它的单位是.f16元素(不是字节)。因为 1 个.f16是 2 字节,所以将步长s转换为字节就是2×s
  • 因为片段大小是 32 字节,所以每一行的起始地址必须是 32 的倍数。
    • 基指针p是第一行的其实地址,p必须是32的倍数。
    • 第二行的其实地址是p+2×s,因此s也必须是32的倍数。
  • 为什么要进行地址对齐?
    • 若地址对齐,硬件可以一次读取完整的段,效率极高。
    • 若地址不对齐,硬件可能需要分两次读取,然后裁剪,再拼接数据,会降低性能,甚至部分硬件不支持非对齐访问,直接导致程序出错。

wmma.load

.f16浮点类型数据load指令:

wmma.load.a.sync.aligned.layout.shape{.ss}.atype r, [p] {, stride};
wmma.load.b.sync.aligned.layout.shape{.ss}.btype r, [p] {, stride};
wmma.load.c.sync.aligned.layout.shape{.ss}.ctype r, [p] {, stride};.layout = {.row, .col};
.shape  = {.m16n16k16, .m8n32k16, .m32n8k16};
.ss     = {.global, .shared{::cta}};
.atype  = {.f16, .s8, .u8};
.btype  = {.f16, .s8, .u8};
.ctype  = {.f16, .f32, .s32};
  • p是起始地址
  • r是加载到的目标寄存器
  • .ss表示支持从global memory或者share memory加载数据
  • {, stride} 表示可选参数

举例:

// 为当前线程分配 8 个 32 位的寄存器
// x<8>:声明一组寄存器变量,变量名前缀为x,数字范围从 0 到 7(即 x0, x1, x2, ..., x7),总共 8 个寄存器
.reg .b32 x<8>;
// 从行主序的f16矩阵B加载元素到寄存器
wmma.load.b.sync.aligned.m16n16k16.row.f16 {x0,x1,x2,x3,x4,x5,x,x7}, [ptr];.reg .b32 x<8>;
// 从f32列主序矩阵C加载元素并缩放这些值
wmma.load.c.sync.aligned.m16n16k16.col.f32 {x0,x1,x2,x3,x4,x5,x6,x7}, [ptr];// 对于C,现在使用{x0, ..., x7}进行实际的wmma.mma运算
mul.f32 x0, x0, 0.1;
// 对所有寄存器x<8>重复此操作;
...
mul.f32 x7, x7, 0.1;
// 现在使用{x0, ..., x7}进行实际的wmma.mma运算.reg .b32 x<4>  // 寄存器x<4>各自包含四个打包的.u8值
// 从整型矩阵A加载元素
wmma.load.a.sync.aligned.m32n8k16.row.u8 {x0,x1,x2,x3}, [ptr];.reg .b32 x<4>;
// 从.bf16矩阵A加载元素:
wmma.load.a.sync.aligned.m16n16k16.row.bf16 {x0,x1,x2,x3}, [ptr];

wmma.store

指令语法:

wmma.store.d.sync.aligned.layout.shape{.ss}.type [p], r {, stride};.layout = {.row, .col};
.shape  = {.m16n16k16, .m8n32k16, .m32n8k16};
.ss     = {.global, .shared{::cta}};
.type   = {.f16, .f32, .s32};

举例:

.reg .b32 x<8>;
// 执行矩阵乘法累加操作 D += A×B。
wmma.mma.sync.m16n16k16.row.col.f32.f32 {d0, d1, d2, d3, d4, d5, d6, d7}, ...;
wmma.store.d.sync.m16n16k16.row.f32 [ptr], {d0, d1, d2, d3, d4, d5, d6, d7};// Store s32 accumulator for m16n16k16 shape:
.reg .b32 d<8>;
wmma.store.d.sync.aligned.m16n16k16.row.s32 [ptr], {d0, d1, d2, d3, d4, d5, d6, d7};

wmma.mma

指令语法:

// Floating point (.f16 multiplicands) wmma.mma
wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype d, a, b, c;// Integer (.u8/.s8 multiplicands) wmma.mma
wmma.mma.sync.aligned.alayout.blayout.shape.s32.atype.btype.s32{.satfinite} d, a, b, c;.alayout = {.row, .col};
.blayout = {.row, .col};
.shape  =  {.m16n16k16, .m8n32k16, .m32n8k16};
.dtype   = {.f16, .f32};
.atype   = {.s8, .u8};
.btype   = {.s8, .u8};
.ctype   = {.f16, .f32};

举例:

.global .align 32 .f16 A[256], B[256];  // 32表示按照32位对齐
.global .align 32 .f32 C[256], D[256];
.reg .b32 a<8> b<8> c<8> d<8>;wmma.load.a.sync.aligned.m16n16k16.global.row.f16{a0, a1, a2, a3, a4, a5, a6, a7}, [A];
wmma.load.b.sync.aligned.m16n16k16.global.col.f16{b0, b1, b2, b3, b4, b5, b6, b7}, [B];wmma.load.c.sync.aligned.m16n16k16.global.row.f32{c0, c1, c2, c3, c4, c5, c6, c7}, [C];wmma.mma.sync.aligned.m16n16k16.row.col.f32.f32{d0, d1, d2, d3, d4, d5, d6, d7},{a0, a1, a2, a3, a4, a5, a6, a7},{b0, b1, b2, b3, b4, b5, b6, b7},{c0, c1, c2, c3, c4, c5, c6, c7};wmma.store.d.sync.aligned.m16n16k16.global.col.f32[D], {d0, d1, d2, d3, d4, d5, d6, d7};

MMA

语法

半精度数据类型
mma.sync.aligned.m8n8k4.alayout.blayout.dtype.f16.f16.ctype  d, a, b, c;
mma.sync.aligned.m16n8k8.row.col.dtype.f16.f16.ctype  d, a, b, c;
mma.sync.aligned.m16n8k16.row.col.dtype.f16.f16.ctype d, a, b, c;.alayout = {.row, .col};  # 指定矩阵 A 的存储布局:行优先
.blayout = {.row, .col};  # 指定矩阵 B 的存储布局:行优先
.ctype   = {.f16, .f32};  # 指定操作数C的数据类型,包含两种
.dtype   = {.f16, .f32};  # 指定操作数D数据类型,包含两种
带block_scaling的半精度数据类型
# block_scale指示了在进行矩阵乘累加之前要对矩阵A和B分别用scale_A和scale_B进行scale
# .stype指定了scale_A和scale_B的数据类型
# .scale_vec_size指定了scale_A的列数和scale_B的行数
# 对于指定.kind::mxf4的mma,当未指定限定符.scale_vec_size时,其默认值为2X。相反,当.kind指定为.kind::mxf8f6f4时,限定符.scale_vec_size默认值为1X。但是,对于.kind::mxf4nvf4,必须提供有效的.scale_vec_size
mma.sync.aligned.m16n8k64.row.col.kind.block_scale{.scale_vec_size}.f32.e2m1.e2m1.f32.stype d, a, b, c, scale-a-data, {byte-id-a, thread-id-a}, scale-b-data, {byte-id-b, thread-id-b};.kind           = {.kind::mxf4};
.scale_vec_size = {.scale_vec::2X};
.stype          = {.ue8m0};mma.sync.aligned.m16n8k64.row.col.kind.block_scale.scale_vec_size.f32.e2m1.e2m1.f32.stype d, a, b, c, scale-a-data, {byte-id-a, thread-id-a}, scale-b-data, {byte-id-b, thread-id-b};.kind           = {.kind::mxf4nvf4};
.scale_vec_size = {.scale_vec::2X, .scale_vec::4X};
.stype          = {.ue8m0, .ue4m3};mma.sync.aligned.m16n8k32.row.col.kind.block_scale{.scale_vec_size}.f32.f8f6f4type.f8f6f4type.f32.stype d, a, b, c, scale-a-data, {byte-id-a, thread-id-a}, scale-b-data, {byte-id-b, thread-id-b};.kind           = {.kind::mxf8f6f4};
.scale_vec_size = {.scale_vec::1X};
.f8f6f4type     = {.e4m3, .e5m2, .e3m2, .e2m3, .e2m1};
.stype          = {.ue8m0};

举例

# 全半精度 (f16) 矩阵乘法
.reg .f16x2 %Ra<2> %Rb<2> %Rc<4> %Rd<4>
mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16
{%Rd0, %Rd1, %Rd2, %Rd3},
{%Ra0, %Ra1},
{%Rb0, %Rb1},
{%Rc0, %Rc1, %Rc2, %Rc3};# 混合精度 (f16 输入,f32 输出)
.reg .f16x2 %Ra<2> %Rb<2> %Rc<4>
.reg .f32 %Rd<8>
mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f16
{%Rd0, %Rd1, %Rd2, %Rd3, %Rd4, %Rd5, %Rd6, %Rd7},
{%Ra0, %Ra1},
{%Rb0, %Rb1},
{%Rc0, %Rc1, %Rc2, %Rc3};

ldmatrix

从共享内存中集中加载一个或多个矩阵,来执行mma指令。

语法
# shape指定加载的矩阵的shape
# r是寄存器,p是操作数
# .num是单次指令加载的矩阵片数
# .ss表示共享内存空间
ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];# dst_fmt数据在寄存器中的存储格式
# src_fmt数据在内存中的存储格式
ldmatrix.sync.aligned.m8n16.num{.ss}.dst_fmt.src_fmt        r, [p];# .trans表示加载时对数据进行转置
ldmatrix.sync.aligned.m16n16.num.trans{.ss}.dst_fmt.src_fmt r, [p];.shape   = {.m8n8, .m16n16};
.num     = {.x1, .x2, .x4};
.ss      = {.shared{::cta}};
.type    = {.b16, .b8};
.dst_fmt = { .b8x16 };
.src_fmt = { .b6x16_p32, .b4x16_p64 };
在cuda kernel中使用
#include <cstdint>
#include <iostream>// 定义一个设备端内联函数,用于从共享内存加载8x8矩阵
// d0: 输出参数,用于存储加载的数据
// address: 输入参数,共享内存中的地址
__device__ __forceinline__ void ldmatrix_sync_aligned_m8n8_x1_b16(uint32_t &d0, const uint32_t &address) {// 使用内联PTX汇编指令加载矩阵// ldmatrix.sync.aligned.m8n8.x1.shared.b16: 同步加载8x8矩阵,每个元素16位// {%0}: 输出寄存器,存储加载的数据// [%1]: 输入寄存器,指定共享内存地址asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];": "=r"(d0)    // 输出约束,表示d0是一个输出寄存器: "r"(address)); // 输入约束,表示address是一个输入寄存器
}__global__ void ldmatrix(uint16_t *value) {constexpr int N = 64;// 声明共享内存数组__shared__ uint16_t smem[N];auto tid = threadIdx.x;// 计算行偏移量:每个线程负责8个元素,所以乘以8const uint32_t offset_rows = sizeof(uint16_t) * (tid % 8) * 8;// 计算最终地址:共享内存基址 + 行偏移// smem是通用指针类型,转换之后,显示的标记为共享内容的指针类型// 为什么编译器不自动转换?在一些场景中,编译器可能无法处理:内联汇编;地址计算const uint32_t address = __cvta_generic_to_shared(smem) + offset_rows;// 初始化共享内存for (uint32_t i = tid; i < N; i += blockDim.x) {smem[i] = i;}__syncthreads();// 声明用于存储加载数据的变量uint32_t frag;// 调用矩阵加载函数ldmatrix_sync_aligned_m8n8_x1_b16(frag, address);// 再次同步,确保所有线程都完成加载__syncthreads();// 从32位数据中提取两个16位值// 提取低16位uint16_t number1 = static_cast<uint16_t>(frag & 0xFFFF);// 提取高16位uint16_t number2 = static_cast<uint16_t>((frag >> 16) & 0xFFFF);// 打印结果printf("%d -> %d  %d   %d   \n", tid, (int)(smem[2 * tid]), (int)number1, (int)number2);
}// 主函数
int main() {// 声明设备端指针uint16_t *d_value;// 分配设备内存cudaMalloc(&d_value, sizeof(uint16_t));// 启动核函数,使用1个块,32个线程ldmatrix<<<1, 32>>>(d_value);// 等待设备完成cudaDeviceSynchronize();// 释放设备内存cudaFree(d_value);return 0;
}

stmatrix

存储一个或者多个矩阵到共享内存

语法
stmatrix.sync.aligned.shape.num{.trans}{.ss}.type [p], r;.shape  = {.m8n8, .m16n8};
.num    = {.x1, .x2, .x4};
.ss     = {.shared{::cta}};
.type   = {.b16, .b8};
在cuda kernel中使用
__device__ __forceinline__ void stmatrix_sync_aligned_m8n8_x1_b16(uint32_t &d0, const uint32_t &address) {asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" ::"r"(address),"r"(d0));
}__global__ void stmatrix(uint16_t *value) {constexpr int N = 64;__shared__ uint16_t smem[N];auto tid = threadIdx.x;const uint32_t offset_rows = sizeof(uint16_t) * (tid % 8) * 8;const uint32_t address = __cvta_generic_to_shared(smem) + offset_rows;uint32_t frag = 0x00000000;frag |= (tid * 2 + 0);frag |= (tid * 2 + 1) << 16;__syncthreads();stmatrix_sync_aligned_m8n8_x1_b16(frag, address);__syncthreads();uint16_t number1 = static_cast<uint16_t>(frag & 0xFFFF);uint16_t number2 = static_cast<uint16_t>((frag >> 16) & 0xFFFF);printf("%d -> %d  %d   %d   \n", tid, (int)(smem[2 * tid]), (int)number1,(int)number2);
}
http://www.dtcms.com/a/303272.html

相关文章:

  • 5G MBS(组播广播服务)深度解析:从标准架构到商用实践
  • 机器学习(重学版)基础篇(算法与模型一)
  • Ansible提权sudo后执行报错
  • 大语言模型(LLM):重塑AI未来的技术革命
  • 11-大语言模型—Transformer 盖楼,BERT 装修,RoBERTa 直接 “拎包入住”|预训练白话指南
  • 【笔记】Handy Multi-Agent Tutorial 第四章 CAMEL框架下的RAG应用(实践)
  • 四、计算机组成原理——第5章:存储系统
  • C++20实战FlamingoIM开发
  • Vue插件与组件核心区别详解
  • 重复文件清理工具,附免费链接
  • 1.2.3 混凝土结构设计构造要求
  • Cesium大气散射效果
  • Datawhale AI夏令营:大模型理解列车排期表
  • Vue组件通信的终极指南
  • Spark SQL 聚合函数:collect_list 与 collect_set 用法详解
  • ADAS测试:如何用自动化手段提升VV效率
  • 04动手学深度学习(下)
  • AOP实现接口幂等
  • B树、B+树、红黑树区别
  • Solidity基础(教程④-ERC-4626收益金库)
  • 被困扰的elementplus样式修改问题:select选择器修改和el-input修改
  • PHP企业级应用架构:微服务通信、分布式事务与性能优化
  • 短剧系统开发上线全流程攻略:从架构设计到性能优化
  • 页面性能优化
  • SpringBoot轻松集成豆包AI
  • Cacti RCE漏洞复现
  • Android Studio关于Connection refused: connect报错
  • “车位到车位”自动驾驶真相
  • pcm,msd调制解调仿真
  • WCF服务通信框架