当前位置: 首页 > news >正文

cuda优化之softmax

一 roofline

roofline分析


内存访问
对于内存访问来说,我们是一次性加载整个向量,然后一次性保存它。
Bytes=2*N*4(每个浮点值占四字节)
浮点计算量
先求max(x)//N次浮点操作,再求x-m//N次浮点操作,exp=e^x//N次浮点操作,s=sum(exp)//N次浮点操作,out=exp/s//N次浮点操作。
FLOPS=5*N

这样得到的信息就是这个算子每加载8字节,我们进行5次浮点运算。

TheoreticalMaximum = (5/8)*理论带宽 1TB/s = 625GFLOPs//理论最高算力

benchmark/batch_size = 128

这里数据看着还是有问题(因为torch和triton实现的kernel表现出来的算力比理论算力都高),所以作者重新分析了一下cuda存储是怎么实现的?

nv的GPU使用了所谓的write-back cache,这本质上意味着在kernel执行期间,我们只写入L2 cache,而全局内存在我们释放缓存块时接收数据。由于我们L2 cache的写入速度远高于global memory的读取速度。唯一的瓶颈就是从global memory中读取,所以我们理论的最大计算强度增加了2倍。(别人的理解:意思是加载到L2上的数据,在内核执行的时候在上面的读和写因为速度很快所以可认为忽略,只计算L2开始从全局读的那部分,2倍也只能说是大概估计)(我的理解:对于内存访问的计算来说,涉及到“一次性加载整个向量,然后一次性保存它”,但是从L2 cache读到shared memory中的速度很快,甚至可以忽略不计,所以主要受限的还是从global memory中加载整个向量,所以内存访问量大概就是N*4,原先的8*N的一半,理论算力就大致变成原来的两倍)

NOTE:这种分析方式的增加仅适用于那些访存(access global memory)数据完全可以放入L2 Cache(example:H800中的L2 cache的大小为64MB)中的kernel,就是说数据量比较小的情况。
对于那些不这样做的情况,我们仍然需要付出多一次访存的代价,所以kernel的计算性能在大的输入尺度上会出现减速的现象。

二 优化过程

2.1 native

#if SOFTMAX_VARIANT == 1block_size = dim3(32, 32, 1);grid_size = dim3(w/32, h/32, 1);AT_DISPATCH_FLOATING_TYPES(x.type(), "softmax_cuda", ([&] {softmax_kernel<scalar_t><<<grid_size, block_size>>>(x.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), w, h);}));
template <typename scalar_t>
__global__ void softmax_kernel(scalar_t* __restrict__ a, scalar_t* __restrict__ b, int w, int h)
{int col = blockIdx.x*blockDim.x + threadIdx.x;int row = blockIdx.y*blockDim.y + threadIdx.y;if (row < h && col < w){float maxval = a[row*w];for (int i = 1; i<w; i++){maxval = fmaxf(maxval, a[row*w + i]);}float divisor = 0.f;for (int i = 0; i<w; i++){divisor += __expf(a[row*w + i] - maxval);}b[row*w + col] = __expf(a[row*w + col]-maxval)/(divisor);}
}

【分析】

每个线程分到一个数据进行处理,

2.2 fast Reduction

  int h = x.size(0);int w = x.size(1);dim3 block_size = dim3(1, BLOCK_DIM_Y, 1);dim3 grid_size = dim3(h, 1, 1);#if SOFTMAX_VARIANT == 2AT_DISPATCH_FLOATING_TYPES(x.type(), "softmax_cuda", ([&] {softmax_kernel2<scalar_t><<<grid_size, block_size>>>(x.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), w, h);}));template <typename scalar_t>
__global__ void softmax_kernel2(scalar_t* __restrict__ a, scalar_t* __restrict__ b, int w, int h)
{int row = blockIdx.x*blockDim.x + threadIdx.x;int ty = threadIdx.y;__shared__ float reduction[BLOCK_DIM_Y]; if (row < h){float maxval = 0;for (int i = ty*BLOCK_DIM_Y; i<min(w, (ty+1)*BLOCK_DIM_Y); i+=1){maxval = fmaxf(maxval, a[row*w + i]);}reduction[ty] = maxval;for(int stride = BLOCK_DIM_Y/2; stride>=1; stride/=2){__syncthreads();if (ty < stride){reduction[ty] = fmaxf(reduction[ty], reduction[ty+stride]);}}__syncthreads();maxval = reduction[0];float divisor = 0.f;for (int i = ty*BLOCK_DIM_Y; i<min(w, (ty+1)*BLOCK_DIM_Y); i+=1){divisor += __expf(a[row*w + i] - maxval);}reduction[ty] = divisor;for(int stride = BLOCK_DIM_Y/2; stride>=1; stride/=2){__syncthreads();if (ty < stride){reduction[ty] = reduction[ty] + reduction[ty+stride];}}__syncthreads();divisor = reduction[0];for (int i = ty; i<w; i+=BLOCK_DIM_Y){b[row*w + i] = __expf(a[row*w + i]-maxval)/divisor;}}
}

 在实际的内存中,数据是以一维数组的形式存储,因此不需要在设置grid和block的时候,对于x和y的分配,不是一定要对应(感觉不是人话,但是不知道如何表达,等我知道怎么表达了,我再来修改)。

每个block分成BLOCK_DIM_Y个thread,每个thread处理BLOCK_DIM_Y个数据

【step1】每个线程处理BLOCK_DIM_Y个数据,并求出这BLOCK_DIM_Y个数据中的最大值

【step2】BLOCK_DIM_Y个数据,两两求最大值

reduction[0];代表一行的最大值 ,且对线程进行分组,后面的线程可以不参与计算,节省功耗

【step3】按照step1和step2的方式求分母

【step4】结果写回

2.3 访存合并

非访存合并模式:

   for (int i = ty*BLOCK_DIM_Y; i<min(w, (ty+1)*BLOCK_DIM_Y); i+=1){maxval = fmaxf(maxval, a[row*w + i]);}

访存合并模式:

for (int i = ty; i<w; i+=BLOCK_DIM_Y){maxval = fmaxf(maxval, a[row*w + i]);}

可以看出来访存合并针对是一个warp,也就是32个线程,两个连续线程访问连续的两个4B数据,是访存合并的基础 

附录

softmax kernel优化(从naive版本到Online Softmax) - 知乎

http://www.dtcms.com/a/279950.html

相关文章:

  • 组件化思想
  • Brooks 低温泵On-Board Cryopump 安装和维护手法Installation and Maintenance Manual
  • aspnetcore Mvc配置选项中的ModelBindingMessageProvider
  • 第二章 基于新版Onenet搭建云服务(stm32物联网)
  • PyTorch中torch.topk()详解:快速获取最大值索引
  • @Resource 注解的空值处理(默认行为与容器实现)
  • 冲刺阶段项目进度压力大,如何组织高效冲刺
  • 大屏搭建多个图表不自适应问题
  • H264编码结构和解析
  • 第四章 uniapp实现兼容多端的树状族谱关系图,剩余组件
  • ESP32 OTA升级详解:使用Arduino OTA库实现无线固件更新
  • HTML 文本格式化标签
  • java--ThreadLocal创建以及get源码解析
  • http常见状态码
  • 苦练Python第18天:Python异常处理锦囊
  • 【论文阅读】Masked Autoencoders Are Effective Tokenizers for Diffusion Models
  • rsyslog简单应用
  • STM32F769I-DISCO 串口调试
  • Linux上基于C/C++头文件查找对应的依赖开发库
  • SAP B1认证资料-题目
  • 分布式系统中实现临时节点授权的高可用性与一致性
  • 哈希扩展 --- 海量数据处理
  • CISSP知识点汇总- 通信与网络安全
  • 15.Python 列表元素的偏移
  • Java学习————————ThreadLocal
  • python Gui界面小白入门学习二
  • python高阶调试技巧,替代print
  • 14.推荐使用 dict.get(key) 而不是 dict[key]
  • redis配置(Xshell连接centos7的基础上)
  • Modbus 开发工具实战:ModScan32 与 Wireshark 抓包分析(一