当前位置: 首页 > news >正文

【大模型LLM学习】Flash-Attention的学习记录

【大模型LLM学习】Flash-Attention的学习记录

  • 0. 前言
  • 1. flash-attention原理简述
  • 2. 从softmax到online softmax
    • 2.1 safe-softmax
    • 2.2 3-pass safe softmax
    • 2.3 Online softmax
    • 2.4 Flash-attention
    • 2.5 Flash-attention tiling

0. 前言

  Flash Attention可以节约模型训练和推理时间,很多模型可以通过config参数来选择attention是标准的attention实现还是flash-attention方式。在这里记录一下flash attention的学习过程,发现了一位博主以及参考的资料特别好:

  • zhihu一位做高性能计算的博主博文
  • 华盛顿大学的课程note

1. flash-attention原理简述

a t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V attention(Q,K,V)=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V attention(Q,K,V)=softmax(dk QKT)V
  标准的attention操作的时间卡点不是在运算上,而是卡在数据读写上。SRAM的读写速度快,但是存储空间有限,无法一次存下来所有的中间计算结果,一次attention计算存在SRAM<->HBM的多次读写操作。
在这里插入图片描述
  与标准的attention操作比较,flash-attention通过减少数据在HBM和SRAM间的读写操作,来节约时间(甚至backward时还进行了重新计算,重新计算的速度也比把数据从HBM读取到SRAM要快)。
https://huggingface.co/docs/text-generation-inference/conceptual/flash_attention

2. 从softmax到online softmax

  直接看flash-attention的论文比较难看明白,发现华盛顿大学的那份note写得特别清晰,跟着它从softmax看到flash-attention会比较容易。

2.1 safe-softmax

  首先是safe的softmax计算方式。原始的softmax,对于N个数:
s o f t m a x ( { x 1 , . . . , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N softmax(\{x_1,...,x_N\})=\left\{\frac{e^{x_i}}{\sum_{j=1}^{N}e^{x_j}}\right\}_{i=1}^{N} softmax({x1,...,xN})={j=1Nexjexi}i=1N
  对于FP16,最大能表示的数据为65536,当 x > = 11 x>=11 x>=11时, e x e^x ex就会超过FP16的最大表示范围影响结果的正确性。为了避免这个问题,SafeSoftmax 通过减去输入向量中的最大值来调整输入,使得最大的指数项变为 e 0 = 1 e^0=1 e0=1从而防止了上溢的发生。同时,由于所有的指数项都除以同一个数,它们的比例关系不会改变,因此也不会影响最终的概率分布。
e x i ∑ j = 1 N e x j = e x i − m ∑ j = 1 N e x j − m , m = m a x { x j } j = 1 N \frac{e^{x_i}}{\sum_{j=1}{N}e^{x_j}}=\frac{e^{x_i-m}}{\sum_{j=1}{N}e^{x_j-m}}, \quad m=max\left\{x_j\right\}_{j=1}^{N} j=1Nexjexi=j=1Nexjmexim,m=max{xj}j=1N

2.2 3-pass safe softmax

  • 对于一个行向量 { x i } i = 1 N \{x_i\}_{i=1}^N {xi}i=1N,最直白的softmax计算方式是直接for循环

在这里插入图片描述
  这个算法计算softmax需要执行3次从1->N的循环,在attention中, { x i } \{x_i\} {xi} Q K T QK^T QKT的结果,但是如果SRAM里面存不下这个大的矩阵,上面的计算过程,就需要从HBM里面加载3次 { x i } \{x_i\} {xi},时间花在了数据读写上。

2.3 Online softmax

  如果能把上面(7)(8)(9)这3个式子的计算放一个for循环,就只需要一次load数据。但是 m N m_N mN是全局最大值,计算 m N m_N mN就已经需要一次遍历了。
  Online softmax算法把(7)(8)进行了合并,把3次遍历缩减为2个。它提出计算 d i ′ = ∑ j = 1 i e x j − m i d_i^{\prime}=\sum_{j=1}^{i}e^{x_j-m_i} di=j=1iexjmi来代替计算 d i d_i di,当算到最后 i = N i=N i=N时会发现, d N = d N ′ d_N=d_N^{\prime} dN=dN。具体的,迭代计算 d i ′ d_i^{\prime} di的方式为:
d i ′ = ∑ j = 1 i e x j − m i = ( ∑ j = 1 i − 1 e x j − m i ) + e x i − m i = ( ∑ j = 1 i − 1 e x j − m i − 1 ) e m i − 1 − m i + e x i − m i = d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} d_i^{\prime} &= \sum_{j=1}^{i} e^{x_j - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_i} \right) + e^{x_i - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_{i-1}} \right) e^{m_{i-1} - m_i} + e^{x_i - m_i} \\ &= d_{i-1}^{\prime} e^{m_{i-1} - m_i} + e^{x_i - m_i} \end{aligned} di=j=1iexjmi=(j=1i1exjmi)+eximi=(j=1i1exjmi1)emi1mi+eximi=di1emi1mi+eximi

  所以就可以用迭代的方式,在找最大值 m N m_N mN的时候,同时来计算 d i ′ d_i^{\prime} di,把(7)和(8)一起计算,这样只需要加载两次 x i x_i xi

在这里插入图片描述

2.4 Flash-attention

  上面的online softmax仍然需要2个for循环,加载2次 x i x_i xi来完成softmax的计算。完成softmax的计算,没法更进一步地压缩到1次遍历。但是attention计算的最终目标是获取输出结果,也就是注意力分数与 V V V相乘的结果 O = A × V O=A \times V O=A×V,计算 O O O可以通过一次遍历完成。
在这里插入图片描述
  可以使用类似online softmax把计算 d i d_i di变成计算 d i ′ d_i^{\prime} di的方式,把 o i o_i oi的计算也改成迭代式的,首先把 a i a_i ai带入 o i o_i oi的表达式
o i = ∑ j = 1 i ( e x j − m N d N ′ V [ j , : ] ) o_i=\sum_{j=1}^{i}\left(\frac{e^{x_j-m_{N}}}{d_N^{\prime}}V[j,:]\right) oi=j=1i(dNexjmNV[j,:])

  可以找到一个 o i ′ o_i^{\prime} oi,它不依赖于全局的 d N ′ d_N^{\prime} dN m N m_N mN
o i ′ = ∑ j = 1 i ( e x j − m i d i ′ V [ j , : ] ) o_i^{\prime}=\sum_{j=1}^{i}\left(\frac{e^{x_j-m_{i}}}{d_i^{\prime}}V[j,:]\right) oi=j=1i(diexjmiV[j,:])

  对于 o i ′ o_i^{\prime} oi的计算可以使用迭代的方式,同样的是有 o N = o N ′ o_N=o_N^{\prime} oN=oN
o i ′ = ∑ j = 1 i e x j − m i d i ′ V [ j , : ] = ( ∑ j = 1 i − 1 e x j − m i d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ e x j − m i e x j − m i − 1 d i − 1 ′ d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ V [ j , : ] ) d i − 1 ′ d i ′ e m i − 1 − m i + e x i − m i d i ′ V [ i , : ] = o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] \begin{aligned} o_i' &= \sum_{j=1}^{i} \frac{e^{x_j - m_i}}{d_i'} V[j,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_i}}{d_i'} V[j,:] \right) + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'} \frac{e^{x_j - m_i}}{e^{x_j - m_{i-1}}} \frac{d_{i-1}'}{d_i'} V[j,:] \right) + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'} V[j,:] \right) \frac{d_{i-1}'}{d_i'} e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= o_{i-1}' \frac{d_{i-1}' e^{m_{i-1} - m_i}}{d_i'} + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \end{aligned} oi=j=1idiexjmiV[j,:]=(j=1i1diexjmiV[j,:])+dieximiV[i,:]=(j=1i1di1exjmi1exjmi1exjmididi1V[j,:])+dieximiV[i,:]=(j=1i1di1exjmi1V[j,:])didi1emi1mi+dieximiV[i,:]=oi1didi1emi1mi+dieximiV[i,:]

  这样计算attention的输出结果可以只进行一次遍历就完成
在这里插入图片描述

2.5 Flash-attention tiling

  上面是每次计算一个元素 [ i ] [i] [i],实际上可以一次读取一个大小为b的块(tile)来计算

在这里插入图片描述在这里插入图片描述

  此外,在flash-attention的paper里面,对 Q Q Q K K K V V V O O O分块,其中 Q Q Q
O O O每块大小为 m i n ( M / 4 d , d ) × d min(M/4d,d) \times d min(M/4d,d)×d K / V K/V K/V的每块大小为 M / 4 d × d M/4d \times d M/4d×d,加起来正好不会超过SRAM的大小M,完整的算法在paper中:
在这里插入图片描述

相关文章:

  • PL/SQLDeveloper中数值类型字段查询后显示为科学计数法的处理方式
  • 主流定位技术:Zigbee、蓝牙、UWB、RFID、5G通信介绍及对比
  • PCB特种工艺应用扩展:厚铜、高频与软硬结合板
  • 视觉SLAM基础补盲
  • 【Kotlin】协程
  • 准确--k8s cgroup问题排查
  • 紫光同创FPGA系列实现Aurora 8b/10b协议
  • 单芯片电流采用电路分享
  • Tongweb配置跳过扫描应用jar的配置指引(by lqw)
  • 山东大学算法设计与分析复习笔记
  • 【PhysUnits】16.2 引入变量后的乘法实现(mul.rs)
  • arduino开发遇见的问题(求助!!)
  • springMVC-10验证及国际化
  • 深入理解 transforms.Normalize():PyTorch 图像预处理中的关键一步
  • 格恩朗液体涡轮流量计 精准测量 助力多元场景
  • 制作一款打飞机游戏67:更好的瞄准
  • Window Server 2019--11 虚拟专用网络
  • 【Java后端基础 005】ThreadLocal-线程数据共享和安全
  • vue3子组件获取并修改父组件的值
  • 企业私有化部署的平价革命:五步实现“低成本高可控”AI落地——破除百万投入迷思,中小企业也能玩转私有化大模型
  • php网站建设与管理/百度搜索平台
  • 在线做插画的网站/seo网络推广是什么意思
  • 黄骅港站/南宁seo多少钱报价
  • 百度收录网站但首页快照不更新/时事新闻最新
  • 网站标题被别人改了 应该怎么办/百度搜索引擎优化的推广计划
  • 济南做网站推广有哪些公司/app推广代理平台