当前位置: 首页 > news >正文

cuda编程笔记(19)-- Transformer注意力机制的实现

Transformer里,注意力机制是核心组件,本文将用cuda手写注意力机制模块。

由于Encoder和Decoder里的注意力有掩码之分,在本文实现里,统一忽略掩码;

训练时QKV的L都是相同的,但是推理时不同,本文暂且只写forward,所以L也默认都是一样的

先理清注意力层的步骤

  1. 输入:q,k,v,维度都是[B, L, D]
  2. 对q,k,v作线性变换,Q=q*q_weight,其他同理,得到Q,K,V,维度依然是[B,L,D];权重的维度为[D,D]
  3. Q,K,V进行多头切分,维度转变为[B,H,L,d_head]其中D=H*d_head
  4. output=Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_{head}}})V
  5. 对output进行合并,concat_output维度转回[B,L,D]
  6. 再最后对concat_output作线性映射,依旧是[B,L,D]*[D,D]=[B,L,D]

线性变换怎么做?

此线性变换,实际上只作用于D这个维度;但是cuda库没用提供三维乘二维的api,但是我们可以将这个变换理解为[B*L,D]*[D,D]如此即可使用cublas库的cublasSgemm作矩阵乘法

多头切分怎么切?

切分逻辑实际上是 [B,L,D] → [B,L,H,d_head]

但是我们却换了个维度访问方式,变为[B,H,L,d_head]

如果仅仅是访问的话,也确实可以自己手动改一下访问逻辑;但是问题是我们之后要按这个维度逻辑作矩阵的乘法,如果内存不变的话,直接使用cublas的乘法库将会得到错误的结果。所以必须要对内存重新进行排布了。

__global__ void split_heads(const float *in,float *out,int B,int H,int L,int d_head){int idx=blockDim.x*blockIdx.x+threadIdx.x;int D=H*d_head;int total=B*D*L;if(idx>=total) return;//当前线程对应in的下标int d=idx%D;//d是D维度的下标int l=(idx/D)%L;int b=idx/(L*D);//计算头h的下标int h=d/d_head;//d变为d_head维度的下标d=d%d_head;int out_offset=b*(L*D)+h*(L*d_head)+l*d_head+d;out[out_offset]=in[idx];
}

Softmax怎么做?

在以前的文章里其实也出现过Softmax,总共出现了两种做法:

1.一个block,block内的线程用warp加速。但是问题是B*L*D结果可能比较大,一个block内放不下这么多线程。

2.多block,一个block内一个线程,一个block负责一个batch。这种方法的问题是block内部只有一个线程,几乎没有用到GPU的并行性。

为什么不使用多block多线程呢?因为block之间的数据是没办法在一个核函数里做Softmax的,必须在主机上再发起一次Softmax任务。所以得调用两次核函数。

幸好cudnn为我们提供了封装好的Softmax函数,不用我们自己去封装了。

cudnnStatus_t cudnnSoftmaxForward(cudnnHandle_t handle,cudnnSoftmaxAlgorithm_t algo,cudnnSoftmaxMode_t mode,const void *alpha,const cudnnTensorDescriptor_t xDesc,const void *x,const void *beta,const cudnnTensorDescriptor_t yDesc,void *y
);

algo (算法)

枚举类型 cudnnSoftmaxAlgorithm_t,决定数值稳定性与速度:

  • CUDNN_SOFTMAX_FAST

    • 速度快,但数值稳定性较差。

  • CUDNN_SOFTMAX_ACCURATE

    • 使用更稳定的计算方式(减去最大值再 exp),常用。

  • CUDNN_SOFTMAX_LOG

    • 输出 log-softmax(对数形式)。

👉 一般训练时用 ACCURATE,推理可以选 FAST

mode (归一化模式)

枚举类型 cudnnSoftmaxMode_t,决定 在哪个维度上归一化

  • CUDNN_SOFTMAX_MODE_INSTANCE

    • 在输入张量的最后一维上做 softmax。

    • 例如输入 [batch, channels,h,w],则对每个样本batch的通道做 softmax(覆盖 C×H×W)。

  • CUDNN_SOFTMAX_MODE_CHANNEL

    • 在每个通道维度上做 softmax。

    • 例如输入 [batch, channels, h, w],则对每个位置的 channels 做 softmax。

那我们的归一化是在哪个维度做的?

这里QKT计算出的维度是[B,H,L,L],Softmax归一化是在最后一个维度上进行的,也就是每L个元素进行一次Softmax;

所以即可以选择[B*H*L,L,1,1]配合CUDNN_SOFTMAX_MODE_CHANNEL;

矩阵乘法怎么乘

QK^T乘法的维度变化是[B,H,L,d_head]*[B,H,d_head,L]=[B,H,L,L],并不是所谓完全的转置;那这个乘法要怎么做呢?如果纯手写kernel,将会比较复杂,相当于做了 B*H次矩阵乘法,这个可以留作以后加速的研究;

实际上,cublas有对应的函数给我们使用

cublasSgemmStridedBatched,这是 cuBLAS 提供的批量矩阵乘法函数,非常适合 多个矩阵同样形状同时做 GEMM 的场景

cublasStatus_t cublasSgemmStridedBatched(cublasHandle_t handle,        // cuBLAS 句柄cublasOperation_t transA,     // A 是否转置cublasOperation_t transB,     // B 是否转置int m,                        // C 的行数int n,                        // C 的列数int k,                        // 矩阵乘法 A*B 的公共维度const float *alpha,            // 缩放系数 αconst float *A, int lda, long long int strideA,  // A 指针、ld、跨步const float *B, int ldb, long long int strideB,  // B 指针、ld、跨步const float *beta,             // 缩放系数 βfloat *C, int ldc, long long int strideC,        // 输出 C 指针、ld、跨步int batchCount                // 批次数量
);

这里需要理解一下m,n,k和stride的含义

就拿[B,H,L,d_head]*[B,H,d_head,L]=[B,H,L,L]作例子;我们总共算了B*H批的[L,d_head]*[d_head,L]=[L,L]的GEMM

这里m,n,k对于的就是cublasSgemm里矩阵的维度,可以参考cuda编程笔记(11)--学习cuBLAS使用-CSDN博客这篇文章

而对于strideA,含义是:矩阵A连续批之间在内存中的元素数。说人话就是一次矩阵乘法涉及了多少元素,对于A来说肯定是L*d_head

注意这个函数也是列优先访问的,所以对于行优先的矩阵乘法,依然要对参数的位置做一些调整,具体做法依然参考上面的文章

拼接怎么拼

拼接和切分同理了,需要内存转移

//拼接 [B, H, L, d_head] → [B, L, D]
__global__ void concat_heads(const float *in,float *out,int B,int H,int L,int d_head){int idx=blockIdx.x*blockDim.x+threadIdx.x;int total=B*H*L*d_head;if(idx>=total) return;//计算本线程对应in的下标int d=idx%d_head;int l=(idx/d_head)%L;int h=(idx/(d_head*L))%H;int b=idx/(d_head*L*H);//对应的out的下标int D=H*d_head;int out_offset=b*(L*D)+l*D+h * d_head + d;out[out_offset]=in[idx];
}

模块化实现

class MultiHeadAttention{
public:MultiHeadAttention(cublasHandle_t &cublas_handle_,cudnnHandle_t &cudnn_handle_,int batch_,int length_,int d_model_,int head_);void forward(float *q_input_,float *k_input_,float *v_input_){q_input=q_input_,k_input=k_input_,v_input=v_input_;const float alpha=1.0f,beta=0.0f;//Q=q_input*weight //[B*L,D]=[B*L,D]*[D,D] //对输入做线性变换,只作用最后一个维度cublasSgemm(cublas_handle,CUBLAS_OP_N,CUBLAS_OP_N,d_model,batch*length,d_model,&alpha,q_weight,d_model,q_input,d_model,&beta,Q,d_model);cublasSgemm(cublas_handle,CUBLAS_OP_N,CUBLAS_OP_N,d_model,batch*length,d_model,&alpha,k_weight,d_model,k_input,d_model,&beta,K,d_model);cublasSgemm(cublas_handle,CUBLAS_OP_N,CUBLAS_OP_N,d_model,batch*length,d_model,&alpha,v_weight,d_model,v_input,d_model,&beta,V,d_model);//多头:[B,L,D]->[B,H,L,d_head]split(Q,Q_split,batch,head,length,d_head);split(K,K_split,batch,head,length,d_head);split(V,V_split,batch,head,length,d_head);//Q × Kᵀ:[B, H, L, d_head] × [B, H, d_head, L] = [B, H, L, L]cublasSgemmStridedBatched(cublas_handle,CUBLAS_OP_T,CUBLAS_OP_N,length,length,d_head,&alpha,K_split,d_head,length*d_head,Q_split,d_head,length*d_head,&beta,score,length,length*length,batch*head);//÷根号d_headscale_kernel<<<(batch*head*length*length+255)/256,256>>>(score,batch*head*length*length,1.0/sqrt(d_head));//softmaxcudnnSoftmaxForward(cudnn_handle,CUDNN_SOFTMAX_ACCURATE,CUDNN_SOFTMAX_MODE_CHANNEL,&alpha,score_desc,score,&beta,score_desc,score);//output=score*V:[B,H,L,L] × [B,H,L,d_head]=[B,H,L,d_head]cublasSgemmStridedBatched(cublas_handle,CUBLAS_OP_N,CUBLAS_OP_N,d_head,length,length,&alpha,V_split,d_head,length*d_head,score,length,length*length,&beta,output,d_head,length*d_head,batch*head);//拼接concat_heads<<<(batch*head*length*d_head+255)/256,256>>>(output,concat_output,batch,head,length,d_head);//多头拼接之后,再映射回 d_model 维度。//concat_output=output*concat_weightcublasSgemm(cublas_handle,CUBLAS_OP_N,CUBLAS_OP_N,d_model,batch*length,d_model,&alpha,concat_weight,d_model,concat_output,d_model,&beta,concat_output,d_model);}~MultiHeadAttention(){cudaFree(q_weight);cudaFree(k_weight);cudaFree(v_weight);cudaFree(concat_weight);cudaFree(Q);cudaFree(K);cudaFree(V);cudaFree(Q_split);cudaFree(K_split);cudaFree(V_split);cudaFree(score);cudaFree(output);cudaFree(concat_output);cudnnDestroyTensorDescriptor(score_desc);}
private:cublasHandle_t &cublas_handle;cudnnHandle_t &cudnn_handle;int batch,length,d_model,head,d_head;//[B,L,D]->[B,H,L,d_head]float *q_weight,*k_weight,*v_weight,*concat_weight;//暂时不带偏置float *q_input,*k_input,*v_input;float *Q,*K,*V;float *Q_split,*K_split,*V_split;//切分的中间值float *score,*output,*concat_output;//score是中间Q × Kᵀ的值,形状是[B, H, L, L]cudnnTensorDescriptor_t score_desc;void split(const float *in,float *out,int B,int H,int L,int d_head){split_heads<<<(B*H*L*d_head+255)/256,256>>>(in,out,B,H,L,d_head);}
};MultiHeadAttention::MultiHeadAttention(cublasHandle_t &cublas_handle_,cudnnHandle_t &cudnn_handle_,int batch_,int length_,int d_model_,int head_):
cublas_handle(cublas_handle_),cudnn_handle(cudnn_handle_),batch(batch_),length(length_),d_model(d_model_){d_head=d_model/head;//分配内存cudaMalloc(&q_weight,d_model*d_model*sizeof(float));cudaMalloc(&k_weight,d_model*d_model*sizeof(float));cudaMalloc(&v_weight,d_model*d_model*sizeof(float));cudaMalloc(&concat_weight,d_model*d_model*sizeof(float));cudaMalloc(&Q,batch*length*d_model*sizeof(float));cudaMalloc(&K,batch*length*d_model*sizeof(float));cudaMalloc(&V,batch*length*d_model*sizeof(float));cudaMalloc(&Q_split,batch*length*d_model*sizeof(float));cudaMalloc(&K_split,batch*length*d_model*sizeof(float));cudaMalloc(&V_split,batch*length*d_model*sizeof(float));cudaMalloc(&score,batch*head*length*length*sizeof(float));cudaMalloc(&output,batch*length*d_model*sizeof(float));cudaMalloc(&concat_output,batch*length*d_model*sizeof(float));//初始化init_uniform << <(d_model*d_model + 255) / 256, 256 >> > (q_weight, d_model*d_model,1, -0.05f, 0.05f);//seed统一设置为1init_uniform << <(d_model*d_model + 255) / 256, 256 >> > (k_weight, d_model*d_model,1, -0.05f, 0.05f);//seed统一设置为1init_uniform << <(d_model*d_model + 255) / 256, 256 >> > (v_weight, d_model*d_model,1, -0.05f, 0.05f);//seed统一设置为1init_uniform << <(d_model*d_model + 255) / 256, 256 >> > (concat_weight, d_model*d_model,1, -0.05f, 0.05f);//seed统一设置为1//张量初始化cudnnCreateTensorDescriptor(&score_desc);cudnnSetTensor4dDescriptor(score_desc,CUDNN_TENSOR_NCHW,CUDNN_DATA_FLOAT,batch*head*length,length,1,1);cudaDeviceSynchronize();
}    

暂时的不足

只写了 forward,没有 backward

  • 在 PyTorch / TensorFlow 这种框架里,写 forward 足够了,因为 autograd 会追踪运算图,自动生成梯度;

  • 但在 CUDA + C++ 框架里写 Layer,就得 手动实现 backward,否则训练用不了。

多头注意力有多个输入,不太符合 Layer 基类接口

  • 之前的 Layer 基类只接受一个 input,但 MHA 需要 (query, key, value) 三个输入。

  • Layer 抽象接口可能 设计得过于死板

两种思路:

  1. 改基类接口:让 forward 接受 std::vector<Tensor> 或者 initializer list,这样可以传多个输入。

  2. 不把 MHA 当作 Layer:可以认为它是 一个子模块(Module),由多个 Layer(Linear, Softmax, Dropout, MatMul)组合而成,而不是一个基本算子。

中间变量太多,内存浪费

我现在每一步都 cudaMalloc 一块显存来存结果,显存碎片化、占用大。


文章转载自:

http://gagkPXUE.nbqwt.cn
http://2jJjdthX.nbqwt.cn
http://hyAdtIxj.nbqwt.cn
http://i7A4qpxx.nbqwt.cn
http://8aGfOIPt.nbqwt.cn
http://f7nqdRp2.nbqwt.cn
http://4ENcfzqa.nbqwt.cn
http://oO3ifuzq.nbqwt.cn
http://gOORHB0p.nbqwt.cn
http://wB6TfgGK.nbqwt.cn
http://EvjlaIan.nbqwt.cn
http://KKNgs31A.nbqwt.cn
http://T6IYAb4o.nbqwt.cn
http://8gHJcWUL.nbqwt.cn
http://2r8yeJLt.nbqwt.cn
http://pZmxaLI7.nbqwt.cn
http://8HYhykHz.nbqwt.cn
http://wIpqMhXr.nbqwt.cn
http://BHazKPgF.nbqwt.cn
http://sJoz31kb.nbqwt.cn
http://mN0UeYJh.nbqwt.cn
http://ghzfF6wK.nbqwt.cn
http://m1aEZKLT.nbqwt.cn
http://RoQ0uofp.nbqwt.cn
http://3pxN3V3W.nbqwt.cn
http://mqLmc188.nbqwt.cn
http://zTZxXJf6.nbqwt.cn
http://8a2meoaB.nbqwt.cn
http://fMKaD4Ad.nbqwt.cn
http://zrv7yXyK.nbqwt.cn
http://www.dtcms.com/a/380819.html

相关文章:

  • Pot Translator,跨平台划词翻译与OCR工具
  • Java面试指南——当对象开启“变形记”:序列化反序列化
  • Vue3组件数据双向绑定
  • 死锁检测算法的实现方式-Java
  • 前端设计模式全解(23 种)
  • 110.for循环执行顺序
  • 【Git】merge 分类
  • 2025最新超详细FreeRTOS入门教程:第十四章 FreeRTOS空闲任务与钩子函数
  • Parasoft 斩获 AutoSec 2025 优秀汽车 AI 测试创新方案奖
  • MATLAB3-2数据存储-台大郭彦甫
  • Spring Cloud Gateway基础复习
  • 【scikit-learn系列文章】
  • 后端编程开发路径:从入门到精通的系统性探索
  • 单片机esp32 基础调试 联网fetch http.begin(targetUrl);
  • rust语言 (1.88) egui (0.32.2) 学习笔记(逐行注释)(二十八)使用图片控件显示图片
  • 补 json的作用
  • windows 装虚拟机
  • mybatisplus 自定义注解和拦截器动态修改sql,实现数据权限控制
  • bat 批处理实现 FFmpeg 命令拼接 png 为 TextAtlas
  • 01数据结构-B树练习及B+树特点
  • 现代化私有相册rgallery
  • 第十九篇|东京世界日本语学校的结构数据建模:制度函数、能力矩阵与升学图谱
  • 装饰你的README
  • 嵌入式Linux学习_rk3588移植无线网卡驱动
  • 【Spring】原理解析:Spring Boot 自动配置进阶探索与优化策略
  • Rust : 关于Deref
  • domain_auto_trans,source_domain,untrusted_app
  • prometheus安装部署与alertmanager邮箱告警
  • 【数据可视化-112】使用PyEcharts绘制TreeMap(矩形树图)完全指南及电商销售数据TreeMap绘制实战
  • rust语言 (1.88) 学习笔记:客户端和服务器端同在一个项目中