当前位置: 首页 > news >正文

Flash Attention学习笔记

fast可以增加模型训练速度,memory efficient 显存高效的,exact:和标准attention得到的结果完全一致,并不降低attention精度。IO-Awareness:通过对IO感知的方式来进行训练的整个算法是以改进IO效率达到的。

首先传统的transformer计算过程:

pytorch写的代码在实际显卡上的attention是如何计算的呢。

  • SRAM:特点是极快,但容量小、成本高、占地方。它主要用在处理器芯片内部,作为缓存。

  • HBM:特点是带宽极高,但成本非常高。它主要用在高端GPU和AI加速卡上,作为主内存,为大规模并行计算提供海量数据吞吐。

  • 由于权重矩阵和输入数据量非常大,它们通常存储在HBM中

  • 一旦数据被加载到SRAM上,计算核心就会高速访问SRAM,执行密集的矩阵乘法运算

  • 1. 从HBM加载Q,K到SRAM
    2. 计算出S= QK^T
    3. 将S写到HBM
    4. 将S加载到SRAM
    5. 计算P=softmax(S)
    6. 将P写出到HBM
    7. 从HBM加载P和V到SRAM
    8. 计算O = PV
    9. 把0 写出到HBM

  • 矩阵Q,K,V维度都是N*d,N是序列长度,d是特征维度。
    可以看出中间有很多临时变量的读写,比如S和P矩阵,他们大小都是随着序列长度的平方增长的。中间临时矩阵占用的显存非常大。比如保留中间结果比如SP会占用显存但是还是需要的,因为反向传播需要他们来计算梯度。在模型训练时制约训练速度有两种情况。

在模型训练时制约训练速度有两种情况。
一种情况是compute-Bound,训练速度的瓶颈在于运算,比如对于大的矩阵乘法还有多channel的卷积操作。这些操作都是需要的数据量不大但是计算很复杂。
第二种情况是memory-bound,训练速度的瓶颈在于对HBM数据的读取速度。
从HBM读取数据的速度跟不上运算的速度,算力在等待数据。主要操作有两类:一位是按位的操作比如relu和dropout,还有一类是规约操作比如sum, softmax这些操作都是需要数据很多但是计算相对简单。

attention计算操作主要是memory bound的计算。

上面右侧的图可以看到compute bound的操作比如矩阵乘法占用的时间很短,但是memory bound占据了很长时间。对于memory bound的优化主要通过融合多个操作。

对于Memory-Bound的优化一般是进行fusion融合操作。不对中间结果缓存,减少HBM访问,节约了原来多个操作之间要存取HBM的时间,让多个操作只要存取一次HBM。我们不保存中间结果在反向传播中重新计算。

显存中的存储是分级的,有芯片内的缓存SRAM(缓存容量小但是访问快),还有芯片外的HBM缓存(容量大但是访问慢),所以对于优化来说应该尽可能让计算访问芯片内的缓存,尽可能减少访问芯片外HBM的显存。

flash attention 着眼于减少IO量。以及通过访问芯片内缓存而加快IO的速度。

当Q和K矩阵很大时,不分块的传统方法会把大部分时间浪费在等待数据从HBM搬运到SRAM上,GPU强大的计算单元大部分时间在“饿着肚子”等数据。

为了实现避免attention matrix从HBM读写通过以下两点实现的:
1. 通过分块计算,融合多个操作,减少中间结果缓存(到HBM)
2. 反向传播时, 重新计算中间结果。

实现了2-4倍速度提升,10-20倍显存占用的节省(从原来的随序列长度平方增长减小到随序列长度线性增长)。

下面我们来看如何通过矩阵分块和融合多个计算来减少对HBM的访问。
暂时先跳过softmax的操作比较特殊后面单独讨论。

1. 从HBM中读取Q的前两行、K转置的前三列、V的前三行,然后传入到SRAM上对他们进行计算。

在SRAM中Q和K的转置得到S并不存入HBM,直接和V的分块进行计算。得到了O的前两行:

因为O是对所有V的一个加权平均,目前得到的结果就是对V的前三行进行加权平均。O用浅色表示因为还只是一个中间结果,后面还需要更新。

2. 接下来K和V的分块还保留在SRAM里,从HBM里读取Q的中间两行,经过同样的计算得到O的第三行和第四行的中间结果。

然后仍然保留K和V的分块在SRAM里

3, 从HBM中读取Q的最后两行经过同样的计算得到O的最后两行的结果

4. 接下来读取K转置的后三列,V的后三行, Q的前两行,得到结果S之后再凑个HBM里读取O之前的保存结果,也就是对V的前三行的加权平均值进行加和。

得到了O前两行的最终结果。

同样保持K和V分块不变,从HBM里读取下一个分块的Q进行计算,从HBM里读取之前的计算中间结果O和更新后存入HBM。最后继续保持SRAM里的K和V分块不变。最后从HBM里读取Q的最后两行进行计算,继续保持SRAM里面的K和V分块不变。加和更新最终存入HBM。

以上完成了attention计算。

通过将矩阵分块以及将多步计算进行融合,中途没有将中间计算结果S存入HBM。大大减少了IO的时间。

接下来看softmax:

softmax是按行进行的,只有一行所有的数据都计算完成后才能进行这里的求和计算。所以我们想要让我们之前的矩阵分块对attention多步进行融合计算得以进行的前提必须解决softmax分块计算问题。

softmax的分块计算:

现在我们训练都是混合精度,在FP16下进行,如果X= 12,则E的X次方就大于FP16所能表示的最大的数了。

为了解决这个数值溢出的问题,人们提出了一种叫做safe softmax的算法。
首先找出从X1到Xn里最大的值m,然后将softmax的分子和分母同时除以E的m次方,softmax结果不变,得到的式子中可以看出e的指数部分就都小于等于0了,这时候用fp16表示就不会有数值溢出的问题了。

在看一下safe softmax的过程:

有一组X通过max(x)求出X里的最大值,通过p(x)将x变化成e的(xi-m(x)),如下图

对于原始2N个X的正确的softmax的值如上图右侧计算过程。

其中p(x)拼接起来的公式是

需要分别给p(x^{1})p(x^{2})一个系数。因为m(x)是m(x^{1})m(x^{2})的最大值,所以m(x)肯定等于其中一个。假设m(x)=m(x^{2})那么后面e的指数项就等于0,那么p(x) = [e^{m(x^{1})-m(x)}p(x^{1}), p(x^{2})]

因为p(x^{2})分块计算的时候减去就是全局最大值,所以此时不需要再进行调整。p(x^{1})在分块计算的时候减去的是局部最大值不是全局最大值,那么它和全局最大值比少了多少呢?就是这里的m(x^{1})-m(x)给他补回来。

所以softmax也可以通过分块来计算了,只是我们需要额外补充几个变量。

可以看到flash attention计算量增加了 但是对HBM访问量大幅度减小,训练时间也是大幅度减小。

flash attention2大致思想相似增加了一些工程优化,

如果只看单次传输,把一大块数据分成很多小块传输,总的数据量不变,总的理论带宽时间应该是一样的。但是,Flash Attention 分块计算能极大减少 IO 时间的根本原因在于:它通过巧妙的分块,避免了将整个庞大的中间结果矩阵反复从慢速内存(如 HBM)读取和写入

http://www.dtcms.com/a/410644.html

相关文章:

  • 解决 QGraphicsDropShadowEffect 导致的 UI 持续刷新
  • 用 LoRA 微调 Qwen3-0.6B 模型,打造专属宠物商店智能客服
  • 建搜索引擎网站衡东网络推广公司
  • Go test 命令完整指南:从基础到高级用法
  • apifox认证登录自动化
  • 江西网站建设哪家专业女装wordpress
  • IDEA JVM优化配置idea64.vmoptions - 保守兼容版本 兼容IDEA 2023.3.6版本【亲测可用】
  • 网站图片像素多少做视频有赚钱的网站
  • APT攻击:隐蔽战场的威胁与防御之道
  • 小兔鲜项目
  • 黑马点评学习笔记01(手机号校验(正则表达式))
  • 声明式事务7
  • 外贸专业网站制作昆明建设网站哪家好
  • 鸿蒙原生contact.queryContacts通讯录查询实现
  • 根据百度地图做网站太原h5建站
  • 【JAVA】从入门到放弃-02-工具、类型、输入输出
  • 伪静态怎么设置(详细教程)
  • 【leetcode】57. 插入区间
  • 多sheet excel 导出
  • 手机移动端网站是什么用什么软件做网站布局
  • cesium-kit:让 Cesium 开发像写 UI 组件一样简单
  • 电子工程师网站wordpress the ken
  • Nginx HTTPS 深入实战 配置、性能与排查全流程(Nginx https
  • 网站建设和优化的营销话术亚马逊雨林生存游戏手机下载
  • 一场“无感换心”手术:金仓数据库如何让电子证照系统平滑告别MongoDB
  • 【开源】基于STM32的新疆地区棉花智能种植系统
  • 高平市规建设局网站短链接生成器
  • 解决SSL握手失败问题:SSLHandshakeException: Received fatal alert: handshake_failure
  • 降级版本Pillow解决freetypefont has no attribute getsize问题
  • 网站设计实例教程wordpress引用文章