cuda编程笔记(19)-- Transformer注意力机制的实现
Transformer里,注意力机制是核心组件,本文将用cuda手写注意力机制模块。
由于Encoder和Decoder里的注意力有掩码之分,在本文实现里,统一忽略掩码;
训练时QKV的L都是相同的,但是推理时不同,本文暂且只写forward,所以L也默认都是一样的
先理清注意力层的步骤
- 输入:q,k,v,维度都是[B, L, D]
- 对q,k,v作线性变换,Q=q*q_weight,其他同理,得到Q,K,V,维度依然是[B,L,D];权重的维度为[D,D]
- Q,K,V进行多头切分,维度转变为[B,H,L,d_head]其中D=H*d_head
- 对output进行合并,concat_output维度转回[B,L,D]
- 再最后对concat_output作线性映射,依旧是[B,L,D]*[D,D]=[B,L,D]
线性变换怎么做?
此线性变换,实际上只作用于D这个维度;但是cuda库没用提供三维乘二维的api,但是我们可以将这个变换理解为[B*L,D]*[D,D]如此即可使用cublas库的cublasSgemm作矩阵乘法
多头切分怎么切?
切分逻辑实际上是 [B,L,D] → [B,L,H,d_head]
但是我们却换了个维度访问方式,变为
[B,H,L,d_head]
如果仅仅是访问的话,也确实可以自己手动改一下访问逻辑;但是问题是我们之后要按这个维度逻辑作矩阵的乘法,如果内存不变的话,直接使用cublas的乘法库将会得到错误的结果。所以必须要对内存重新进行排布了。
__global__ void split_heads(const float *in,float *out,int B,int H,int L,int d_head){int idx=blockDim.x*blockIdx.x+threadIdx.x;int D=H*d_head;int total=B*D*L;if(idx>=total) return;//当前线程对应in的下标int d=idx%D;//d是D维度的下标int l=(idx/D)%L;int b=idx/(L*D);//计算头h的下标int h=d/d_head;//d变为d_head维度的下标d=d%d_head;int out_offset=b*(L*D)+h*(L*d_head)+l*d_head+d;out[out_offset]=in[idx];
}
Softmax怎么做?
在以前的文章里其实也出现过Softmax,总共出现了两种做法:
1.一个block,block内的线程用warp加速。但是问题是B*L*D结果可能比较大,一个block内放不下这么多线程。
2.多block,一个block内一个线程,一个block负责一个batch。这种方法的问题是block内部只有一个线程,几乎没有用到GPU的并行性。
为什么不使用多block多线程呢?因为block之间的数据是没办法在一个核函数里做Softmax的,必须在主机上再发起一次Softmax任务。所以得调用两次核函数。
幸好cudnn为我们提供了封装好的Softmax函数,不用我们自己去封装了。
cudnnStatus_t cudnnSoftmaxForward(cudnnHandle_t handle,cudnnSoftmaxAlgorithm_t algo,cudnnSoftmaxMode_t mode,const void *alpha,const cudnnTensorDescriptor_t xDesc,const void *x,const void *beta,const cudnnTensorDescriptor_t yDesc,void *y
);
algo
(算法)
枚举类型 cudnnSoftmaxAlgorithm_t
,决定数值稳定性与速度:
-
CUDNN_SOFTMAX_FAST
-
速度快,但数值稳定性较差。
-
-
CUDNN_SOFTMAX_ACCURATE
-
使用更稳定的计算方式(减去最大值再 exp),常用。
-
-
CUDNN_SOFTMAX_LOG
-
输出 log-softmax(对数形式)。
-
👉 一般训练时用 ACCURATE
,推理可以选 FAST
。
mode
(归一化模式)
枚举类型 cudnnSoftmaxMode_t
,决定 在哪个维度上归一化:
-
CUDNN_SOFTMAX_MODE_INSTANCE
-
在输入张量的最后一维上做 softmax。
-
例如输入
[batch, channels,h,w]
,则对每个样本batch的通道做 softmax(覆盖 C×H×W)。
-
-
CUDNN_SOFTMAX_MODE_CHANNEL
-
在每个通道维度上做 softmax。
-
例如输入
[batch, channels, h, w]
,则对每个位置的channels
做 softmax。
-
那我们的归一化是在哪个维度做的?
这里QKT计算出的维度是[B,H,L,L],Softmax归一化是在最后一个维度上进行的,也就是每L个元素进行一次Softmax;
所以即可以选择[B*H*L,L,1,1]配合CUDNN_SOFTMAX_MODE_CHANNEL;
矩阵乘法怎么乘
乘法的维度变化是[B,H,L,d_head]*[B,H,d_head,L]=[B,H,L,L],并不是所谓完全的转置;那这个乘法要怎么做呢?如果纯手写kernel,将会比较复杂,相当于做了 B*H次矩阵乘法,这个可以留作以后加速的研究;
实际上,cublas有对应的函数给我们使用
cublasSgemmStridedBatched
,这是 cuBLAS 提供的批量矩阵乘法函数,非常适合 多个矩阵同样形状同时做 GEMM 的场景
cublasStatus_t cublasSgemmStridedBatched(cublasHandle_t handle, // cuBLAS 句柄cublasOperation_t transA, // A 是否转置cublasOperation_t transB, // B 是否转置int m, // C 的行数int n, // C 的列数int k, // 矩阵乘法 A*B 的公共维度const float *alpha, // 缩放系数 αconst float *A, int lda, long long int strideA, // A 指针、ld、跨步const float *B, int ldb, long long int strideB, // B 指针、ld、跨步const float *beta, // 缩放系数 βfloat *C, int ldc, long long int strideC, // 输出 C 指针、ld、跨步int batchCount // 批次数量
);
这里需要理解一下m,n,k和stride的含义
就拿[B,H,L,d_head]*[B,H,d_head,L]=[B,H,L,L]作例子;我们总共算了B*H批的[L,d_head]*[d_head,L]=[L,L]的GEMM
这里m,n,k对于的就是cublasSgemm里矩阵的维度,可以参考cuda编程笔记(11)--学习cuBLAS使用-CSDN博客这篇文章
而对于strideA,含义是:矩阵A连续批之间在内存中的元素数。说人话就是一次矩阵乘法涉及了多少元素,对于A来说肯定是L*d_head
注意这个函数也是列优先访问的,所以对于行优先的矩阵乘法,依然要对参数的位置做一些调整,具体做法依然参考上面的文章
拼接怎么拼
拼接和切分同理了,需要内存转移
//拼接 [B, H, L, d_head] → [B, L, D]
__global__ void concat_heads(const float *in,float *out,int B,int H,int L,int d_head){int idx=blockIdx.x*blockDim.x+threadIdx.x;int total=B*H*L*d_head;if(idx>=total) return;//计算本线程对应in的下标int d=idx%d_head;int l=(idx/d_head)%L;int h=(idx/(d_head*L))%H;int b=idx/(d_head*L*H);//对应的out的下标int D=H*d_head;int out_offset=b*(L*D)+l*D+h * d_head + d;out[out_offset]=in[idx];
}
模块化实现
class MultiHeadAttention{
public:MultiHeadAttention(cublasHandle_t &cublas_handle_,cudnnHandle_t &cudnn_handle_,int batch_,int length_,int d_model_,int head_);void forward(float *q_input_,float *k_input_,float *v_input_){q_input=q_input_,k_input=k_input_,v_input=v_input_;const float alpha=1.0f,beta=0.0f;//Q=q_input*weight //[B*L,D]=[B*L,D]*[D,D] //对输入做线性变换,只作用最后一个维度cublasSgemm(cublas_handle,CUBLAS_OP_N,CUBLAS_OP_N,d_model,batch*length,d_model,&alpha,q_weight,d_model,q_input,d_model,&beta,Q,d_model);cublasSgemm(cublas_handle,CUBLAS_OP_N,CUBLAS_OP_N,d_model,batch*length,d_model,&alpha,k_weight,d_model,k_input,d_model,&beta,K,d_model);cublasSgemm(cublas_handle,CUBLAS_OP_N,CUBLAS_OP_N,d_model,batch*length,d_model,&alpha,v_weight,d_model,v_input,d_model,&beta,V,d_model);//多头:[B,L,D]->[B,H,L,d_head]split(Q,Q_split,batch,head,length,d_head);split(K,K_split,batch,head,length,d_head);split(V,V_split,batch,head,length,d_head);//Q × Kᵀ:[B, H, L, d_head] × [B, H, d_head, L] = [B, H, L, L]cublasSgemmStridedBatched(cublas_handle,CUBLAS_OP_T,CUBLAS_OP_N,length,length,d_head,&alpha,K_split,d_head,length*d_head,Q_split,d_head,length*d_head,&beta,score,length,length*length,batch*head);//÷根号d_headscale_kernel<<<(batch*head*length*length+255)/256,256>>>(score,batch*head*length*length,1.0/sqrt(d_head));//softmaxcudnnSoftmaxForward(cudnn_handle,CUDNN_SOFTMAX_ACCURATE,CUDNN_SOFTMAX_MODE_CHANNEL,&alpha,score_desc,score,&beta,score_desc,score);//output=score*V:[B,H,L,L] × [B,H,L,d_head]=[B,H,L,d_head]cublasSgemmStridedBatched(cublas_handle,CUBLAS_OP_N,CUBLAS_OP_N,d_head,length,length,&alpha,V_split,d_head,length*d_head,score,length,length*length,&beta,output,d_head,length*d_head,batch*head);//拼接concat_heads<<<(batch*head*length*d_head+255)/256,256>>>(output,concat_output,batch,head,length,d_head);//多头拼接之后,再映射回 d_model 维度。//concat_output=output*concat_weightcublasSgemm(cublas_handle,CUBLAS_OP_N,CUBLAS_OP_N,d_model,batch*length,d_model,&alpha,concat_weight,d_model,concat_output,d_model,&beta,concat_output,d_model);}~MultiHeadAttention(){cudaFree(q_weight);cudaFree(k_weight);cudaFree(v_weight);cudaFree(concat_weight);cudaFree(Q);cudaFree(K);cudaFree(V);cudaFree(Q_split);cudaFree(K_split);cudaFree(V_split);cudaFree(score);cudaFree(output);cudaFree(concat_output);cudnnDestroyTensorDescriptor(score_desc);}
private:cublasHandle_t &cublas_handle;cudnnHandle_t &cudnn_handle;int batch,length,d_model,head,d_head;//[B,L,D]->[B,H,L,d_head]float *q_weight,*k_weight,*v_weight,*concat_weight;//暂时不带偏置float *q_input,*k_input,*v_input;float *Q,*K,*V;float *Q_split,*K_split,*V_split;//切分的中间值float *score,*output,*concat_output;//score是中间Q × Kᵀ的值,形状是[B, H, L, L]cudnnTensorDescriptor_t score_desc;void split(const float *in,float *out,int B,int H,int L,int d_head){split_heads<<<(B*H*L*d_head+255)/256,256>>>(in,out,B,H,L,d_head);}
};MultiHeadAttention::MultiHeadAttention(cublasHandle_t &cublas_handle_,cudnnHandle_t &cudnn_handle_,int batch_,int length_,int d_model_,int head_):
cublas_handle(cublas_handle_),cudnn_handle(cudnn_handle_),batch(batch_),length(length_),d_model(d_model_){d_head=d_model/head;//分配内存cudaMalloc(&q_weight,d_model*d_model*sizeof(float));cudaMalloc(&k_weight,d_model*d_model*sizeof(float));cudaMalloc(&v_weight,d_model*d_model*sizeof(float));cudaMalloc(&concat_weight,d_model*d_model*sizeof(float));cudaMalloc(&Q,batch*length*d_model*sizeof(float));cudaMalloc(&K,batch*length*d_model*sizeof(float));cudaMalloc(&V,batch*length*d_model*sizeof(float));cudaMalloc(&Q_split,batch*length*d_model*sizeof(float));cudaMalloc(&K_split,batch*length*d_model*sizeof(float));cudaMalloc(&V_split,batch*length*d_model*sizeof(float));cudaMalloc(&score,batch*head*length*length*sizeof(float));cudaMalloc(&output,batch*length*d_model*sizeof(float));cudaMalloc(&concat_output,batch*length*d_model*sizeof(float));//初始化init_uniform << <(d_model*d_model + 255) / 256, 256 >> > (q_weight, d_model*d_model,1, -0.05f, 0.05f);//seed统一设置为1init_uniform << <(d_model*d_model + 255) / 256, 256 >> > (k_weight, d_model*d_model,1, -0.05f, 0.05f);//seed统一设置为1init_uniform << <(d_model*d_model + 255) / 256, 256 >> > (v_weight, d_model*d_model,1, -0.05f, 0.05f);//seed统一设置为1init_uniform << <(d_model*d_model + 255) / 256, 256 >> > (concat_weight, d_model*d_model,1, -0.05f, 0.05f);//seed统一设置为1//张量初始化cudnnCreateTensorDescriptor(&score_desc);cudnnSetTensor4dDescriptor(score_desc,CUDNN_TENSOR_NCHW,CUDNN_DATA_FLOAT,batch*head*length,length,1,1);cudaDeviceSynchronize();
}
暂时的不足
只写了 forward,没有 backward
-
在 PyTorch / TensorFlow 这种框架里,写
forward
足够了,因为 autograd 会追踪运算图,自动生成梯度; -
但在 CUDA + C++ 框架里写 Layer,就得 手动实现 backward,否则训练用不了。
多头注意力有多个输入,不太符合 Layer 基类接口
-
之前的
Layer
基类只接受一个input
,但 MHA 需要(query, key, value)
三个输入。 -
Layer 抽象接口可能 设计得过于死板。
两种思路:
-
改基类接口:让
forward
接受std::vector<Tensor>
或者 initializer list,这样可以传多个输入。 -
不把 MHA 当作 Layer:可以认为它是 一个子模块(Module),由多个 Layer(Linear, Softmax, Dropout, MatMul)组合而成,而不是一个基本算子。
中间变量太多,内存浪费
我现在每一步都 cudaMalloc
一块显存来存结果,显存碎片化、占用大。