
layernorm backward CUDA optimization analysis

Overview

This article is aimed at readers who already have a CUDA background and need to implement layernorm backward quickly. For a detailed treatment of the math behind layernorm backward and the finer points of each optimization, please see the articles in the reference links; this post focuses on the code. Corrections and suggestions are very welcome.

The theory and optimization methods of layernorm_bwd have already been explained in detail elsewhere (see the reference links), so they are not repeated here; this post simply reproduces the common layernorm_bwd optimizations in code.

1. layernorm_bwd algorithm and CPU implementation

  • layernorm_bwd formula derivation:
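Written out per row of length H = hidden_dim, with upstream gradient dy, input x, weight γ, saved row mean μ and reciprocal standard deviation r, the gradients computed by the CPU reference below are (the symbol names here are my own notation):

\hat{x}_j = (x_j - \mu)\, r, \qquad g_j = \gamma_j \, dy_j
\bar{g} = \frac{1}{H}\sum_{j=1}^{H} g_j, \qquad \overline{g\hat{x}} = \frac{1}{H}\sum_{j=1}^{H} g_j \hat{x}_j
dx_j \mathrel{+}= r\,\bigl(g_j - \bar{g} - \hat{x}_j\,\overline{g\hat{x}}\bigr)
d\gamma_j \mathrel{+}= \hat{x}_j\, dy_j, \qquad d\beta_j \mathrel{+}= dy_j \quad \text{(accumulated over all } batch \times seq\_len \text{ rows)}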

template<typename T, typename T_ACC>
void layernorm_backward_cpu(T* dinput, T* dweight, T* dbias, T* doutput,
                            T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                            const int batch, const int seq_len, const int hidden_dim)
{
    for (int b = 0; b < batch; b++) {
        for (int i = 0; i < seq_len; i++) {
            const T* doutput_offset = doutput + b * seq_len * hidden_dim + i * hidden_dim;
            T* dinput_offset = dinput + b * seq_len * hidden_dim + i * hidden_dim;
            const T* input_offset = input + b * seq_len * hidden_dim + i * hidden_dim;
            const T_ACC mean_val = mean[b * seq_len + i];
            const T_ACC rstd_val = rstd[b * seq_len + i];
            // first pass: reduce dnorm and dnorm * norm over the row
            T dnorm_mean = 0.0f;
            T dnorm_norm_mean = 0.0f;
            for (int j = 0; j < hidden_dim; j++) {
                T norm_bti = (input_offset[j] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
                T dnorm_i = weight[j] * doutput_offset[j];
                dnorm_mean += dnorm_i;
                dnorm_norm_mean += dnorm_i * norm_bti;
            }
            dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
            dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);
            // second pass: accumulate the three gradients
            for (int j = 0; j < hidden_dim; j++) {
                T norm_bti = (input_offset[j] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
                T dnorm_i = weight[j] * doutput_offset[j];
                // gradient to bias
                dbias[j] += doutput_offset[j];
                // gradient to weight
                dweight[j] += norm_bti * doutput_offset[j];
                // gradient to input
                T dval = 0.0f;
                dval += dnorm_i;
                dval -= dnorm_mean;
                dval -= norm_bti * dnorm_norm_mean;
                dval *= rstd_val;
                dinput_offset[j] += dval;
            }
        }
    }
}
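A minimal host-side usage sketch (my own example, not from the original post): the shapes follow section 3, the gradient buffers must be zero-initialized because the reference accumulates with +=, and mean/rstd are assumed to be the per-row statistics saved by the layernorm forward pass.

// Usage sketch for the CPU reference (assumed shapes: batch=16, seq_len=64, hidden_dim=2048).
#include <vector>

int main() {
    const int batch = 16, seq_len = 64, hidden_dim = 2048;
    const int rows = batch * seq_len;
    std::vector<float> input(rows * hidden_dim), doutput(rows * hidden_dim), weight(hidden_dim);
    std::vector<float> mean(rows), rstd(rows);            // saved by the forward pass
    std::vector<float> dinput(rows * hidden_dim, 0.0f);   // gradients start at zero (the reference uses +=)
    std::vector<float> dweight(hidden_dim, 0.0f), dbias(hidden_dim, 0.0f);

    // ... fill input / doutput / weight / mean / rstd ...

    layernorm_backward_cpu<float, float>(dinput.data(), dweight.data(), dbias.data(),
                                         doutput.data(), input.data(), weight.data(),
                                         mean.data(), rstd.data(),
                                         batch, seq_len, hidden_dim);
    return 0;
}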

2. layernorm_bwd CUDA optimization methods and implementations

2.1 layernorm_bwd_v1

  • Approach: in v1 each thread handles one row, i.e. there are batch*seq_len threads in total and each thread loops over the hidden_dim elements of its row;
template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel1(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < batch * seq_len) {
        const T* doutput_offset = doutput + idx * hidden_dim;
        T* dinput_offset = dinput + idx * hidden_dim;
        const T* input_offset = input + idx * hidden_dim;
        const T_ACC mean_val = mean[idx];
        const T_ACC rstd_val = rstd[idx];
        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for (int i = 0; i < hidden_dim; i++) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);
        for (int i = 0; i < hidden_dim; i++) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            // gradient to bias (accumulated across rows, hence atomicAdd)
            atomicAdd(&(dbias[i]), doutput_offset[i]);
            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * doutput_offset[i]);
            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
    dim3 block(256, 1);
    dim3 grid((batch * seq_len + block.x - 1) / block.x, 1);   // round up so every row gets a thread
    util::print_cuda_cfg(grid, block);
    layernorm_backward_kernel1<T, T_ACC><<<grid, block>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
                                                          input_gpu, weight_gpu, mean_gpu, rstd_gpu,
                                                          batch, seq_len, hidden_dim);

2.2 layernorm_bwd_v2

  • Approach: in v2 each warp handles one row, i.e. there are batch*seq_len warps in total and each warp strides over the hidden_dim elements of its row; the per-row partial sums (dnorm_mean and dnorm_norm_mean) are reduced inside the warp with warp shuffle instructions.
template <typename T>
__device__ T warpReduceSum(T val) {
#pragma unroll
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel2(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    int tx = threadIdx.x;
    int by = blockIdx.y;
    if (by < batch * seq_len) {
        const T* doutput_offset = doutput + by * hidden_dim;
        T* dinput_offset = dinput + by * hidden_dim;
        const T* input_offset = input + by * hidden_dim;
        const T_ACC mean_val = mean[by];
        const T_ACC rstd_val = rstd[by];
        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        // reduce the per-thread partial sums across the warp
        dnorm_mean = warpReduceSum<T>(dnorm_mean);
        dnorm_norm_mean = warpReduceSum<T>(dnorm_norm_mean);
        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            // gradient to bias
            atomicAdd(&(dbias[i]), doutput_offset[i]);
            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * doutput_offset[i]);
            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
    dim3 block(32, 1);
    dim3 grid(1, batch * seq_len);
    layernorm_backward_kernel2<T, T_ACC><<<grid, block>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
                                                          input_gpu, weight_gpu, mean_gpu, rstd_gpu,
                                                          batch, seq_len, hidden_dim);

2.3 layernorm_bwd_v3

  • Approach: like v2, 32 threads still handle one row, but in this version doutput is staged in shared memory so that it is not read from global memory multiple times.
template <typename T>
__device__ T warpReduceSum(T val) {
#pragma unroll
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel3(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    int tx = threadIdx.x;
    int by = blockIdx.y;
    extern __shared__ unsigned char tmp_smem[];
    T* smem = reinterpret_cast<T*>(tmp_smem);
    if (by < batch * seq_len) {
        const T* doutput_offset = doutput + by * hidden_dim;
        T* dinput_offset = dinput + by * hidden_dim;
        const T* input_offset = input + by * hidden_dim;
        const T_ACC mean_val = mean[by];
        const T_ACC rstd_val = rstd[by];
        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = warpReduceSum<T>(dnorm_mean);
        dnorm_norm_mean = warpReduceSum<T>(dnorm_norm_mean);
        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            // stage this thread's doutput element in shared memory and reuse it three times
            smem[tx] = doutput_offset[i];
            __syncthreads();
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * smem[tx];
            // gradient to bias
            atomicAdd(&(dbias[i]), smem[tx]);
            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * smem[tx]);
            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
    dim3 block(32, 1);
    dim3 grid(1, batch * seq_len);
    size_t smem_size = sizeof(T) * block.x;
    layernorm_backward_kernel3<T, T_ACC><<<grid, block, smem_size>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
                                                                     input_gpu, weight_gpu, mean_gpu, rstd_gpu,
                                                                     batch, seq_len, hidden_dim);

2.4 layernorm_bwd_v4

  • Approach: building on v3, v4 uses 1024 threads to loop over one row, so the partial sums are reduced with a block-level reduction (blockReduceSum) instead of a single warp reduction.
#define WARP_SIZE 32   // assumed warp width; the original presumably defines this elsewhere

template <typename T>
__device__ T warpReduceSum(T val) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

template<typename T>
__device__ __inline__ T blockReduceSum(T val) {
    __shared__ T shared[WARP_SIZE];
    __shared__ T ret;
    int warp_id = threadIdx.x / WARP_SIZE;
    int lane_id = threadIdx.x % WARP_SIZE;
    // reduce within each warp, then let the first warp reduce the per-warp partial sums
    val = warpReduceSum(val);
    if (lane_id == 0) {
        shared[warp_id] = val;
    }
    __syncthreads();
    val = (threadIdx.x < WARP_SIZE) ? shared[threadIdx.x] : (T)(0.0f);
    val = warpReduceSum(val);
    if (threadIdx.x == 0) {
        ret = val;
    }
    __syncthreads();
    return ret;   // every thread gets the block-wide sum
}

template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel4(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    int tx = threadIdx.x;
    int by = blockIdx.y;
    extern __shared__ unsigned char tmp_smem[];
    T* smem = reinterpret_cast<T*>(tmp_smem);
    if (by < batch * seq_len) {
        const T* doutput_offset = doutput + by * hidden_dim;
        T* dinput_offset = dinput + by * hidden_dim;
        const T* input_offset = input + by * hidden_dim;
        const T_ACC mean_val = mean[by];
        const T_ACC rstd_val = rstd[by];
        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = blockReduceSum<T>(dnorm_mean);
        dnorm_norm_mean = blockReduceSum<T>(dnorm_norm_mean);
        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            smem[tx] = doutput_offset[i];
            __syncthreads();
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * smem[tx];
            // gradient to bias
            atomicAdd(&(dbias[i]), smem[tx]);
            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * smem[tx]);
            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
    dim3 block(1024, 1);
    dim3 grid(1, batch * seq_len);
    size_t smem_size = sizeof(T) * block.x;
    util::print_cuda_cfg(grid, block);
    layernorm_backward_kernel4<T, T_ACC><<<grid, block, smem_size>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
                                                                     input_gpu, weight_gpu, mean_gpu, rstd_gpu,
                                                                     batch, seq_len, hidden_dim);

2.5 Other layernorm_bwd optimizations

The performance bottleneck of v4 is the atomicAdd updates to dbias and dweight: every memory location of dbias and dweight receives accumulations from batch * seq_len threads, and those updates serialize, which is expensive. One remedy is to let each block(1024, 1) process several rows: each block first accumulates its rows' smem[tx] and norm_bti * smem[tx] contributions into registers, and only then issues atomicAdd with the per-block partial sums. This reduces the number of threads performing atomicAdd on each location, cuts down the serialized work, and therefore improves performance. A sketch of this idea is given below.
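A minimal sketch of that idea, under my own assumptions (the ROWS_PER_BLOCK parameter, the kernel5 name, and the choice to leave the dinput part unchanged are mine, not from the original post):

// Sketch: each block handles ROWS_PER_BLOCK rows. For every column a thread owns,
// it first accumulates the dbias / dweight contributions of all its rows in
// registers, then issues a single atomicAdd per column per block instead of one per row.
template<typename T, typename T_ACC, int ROWS_PER_BLOCK>
__global__ void layernorm_backward_kernel5(T* dweight, T* dbias, const T* doutput,
                                           const T* input, const T_ACC* mean, const T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    const int row_start = blockIdx.y * ROWS_PER_BLOCK;
    const int num_rows = batch * seq_len;

    // each thread owns columns i = threadIdx.x, threadIdx.x + blockDim.x, ...
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
        T dbias_acc = 0.0f;     // per-block partial sum for dbias[i]
        T dweight_acc = 0.0f;   // per-block partial sum for dweight[i]

        for (int r = row_start; r < row_start + ROWS_PER_BLOCK && r < num_rows; r++) {
            const T* doutput_offset = doutput + r * hidden_dim;
            const T* input_offset = input + r * hidden_dim;
            const T norm_bti = (input_offset[i] - static_cast<T>(mean[r])) * static_cast<T>(rstd[r]);
            dbias_acc += doutput_offset[i];
            dweight_acc += norm_bti * doutput_offset[i];
        }
        // one atomicAdd per column per block instead of one per row
        atomicAdd(&dbias[i], dbias_acc);
        atomicAdd(&dweight[i], dweight_acc);
    }
    // dinput would still be computed per row as in kernel4 (omitted here for brevity).
}

With ROWS_PER_BLOCK = 8 and grid(1, (batch * seq_len + 7) / 8), each dbias/dweight location receives roughly 8x fewer atomic updates than in v4.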

3. Performance comparison of the layernorm_bwd versions

Data type and problem size: FP32, batch = 16, seq_len = 64, hidden_dim = 2048
Hardware platform: A100-SXM

layernorm_bwd version    cycle
layernorm_bwd_v1         7482424
layernorm_bwd_v2         251740
layernorm_bwd_v3         253976
layernorm_bwd_v4         98369

References

No.  Link                                      Note
1    https://zhuanlan.zhihu.com/p/694974164    layernorm CUDA code implementation
2    https://www.jianshu.com/p/db89d62e1974    layernorm backward derivation formulas