当前位置: 首页 > news >正文

flash attention2 计算过程的探索和学习

之前初步探索了flash attention相比标准attention性能提升的原因。

https://blog.csdn.net/liliang199/article/details/151789333

这里尝试进一步探索,解析flash attention2相比flash attention的性能提升的原因。

首先了解稳定版softmax、以及基于稳定版softmax的标准attention,然后参考论文,使用公式表示flash attention计算的计算过程,进而引出flash attention2的计算过程。

1 稳定版softmax

考虑到数值过大,指数运算后可能会溢出,针对如下所示的softmax计算过程,一般会减去x中最大元素的值。

处理前后的softmax公式如下

softmax(x) = \frac{e^{x_i}}{\sum e^{x_j}} = \frac{e^{x_i-x_{max}}}{\sum e^{x_j-x_{max}}}

由于是减去x_max后做指数运算,相当于分子和分母同时除以exp(x_max),不改变softmax特性。

2 标准attention

softmax是attention最核心的运算,这里设定seq_len=2,查询Q、键K=[K1, K2]和值V。

针对K1和K2,S1和S2表示如下。

S^{(1)} = Q (K^{(1)})^{T}, S^{(2)} = Q (K^{(2)})^{T}

基于稳定版softmax,attention计算过程如下。

m = max(rowmax(S^{(1)}), rowmax(S^{(2)})) \in \mathbb{R}^{B_r} \\ l = rowsum(e^{S^{(1)}-m}) + rowsum(e^{S^{(2)}-m}) \in \mathbb{R}^{B_r} \\ P = \left[P^{(1)}, P^{(2)} \right] = diag(l)^{-1} \left[e^{S^{(1)}-m}, e^{S^{(2)}-m}\right] \in \mathbb{R}^{B_r \times 2B_c} \\ O = \left[P^{(1)}, P^{(2)} \right] \begin{bmatrix} V^{(1)}\\V^{(2)} \end{bmatrix} \\= diag(l)^{-1} (e^{S^{(1)}-m} V^{(1)} + e^{S^{(2)}-m} V^{(2)}) \in \mathbb{R}^{B_r \times 2B_c}

3 flash attention 

flash attention的计算过程示例如下

\displaystyle m^{(1)} = rowmax(S^{(1)}) \in \mathbb{R}^{B_r} \\ l^{(1)} = rowsum(e^{S^{(1)}-m^{(1)}}) \in \mathbb{R}^{B_r} \\ \hat{P} ^{(1)} = diag(l^{(1)})^{-1} e^{S^{(1)}-m^{(1)}} \in \mathbb{R}^{B_r \times B_c} \\O^{(1)} = \hat{P} ^{(1)} V^{(1)} = diag(l^{(1)})^{-1}e^{S^{(1)}-m^{(1)}}V^{(1)} \in \mathbb{R}^{B_r \times d} \\ m^{(2)} = max(m^{(1)}, rowmax(S^{(2)})) = m \\ l^{(2)} = e^{m^{(1)} - m^{(2)}} l^{(1)} + rowsum(e^{S^{(2)}-m^{(2)}}) \\ \hat{P} ^{(2)} = diag(l^{(2)})^{-1} e^{S^{(2)}-m^{(2)}} \\ O^{(2)} = diag(l^{(1)}/l^{(2)}) diag(e^{m^{(1)}-m^{(2)}})O^{(1)} + \hat{P} ^{(2)} V^{(2)} \\= diag(l^{(2)})^{-1} e^{S^{(1)}-m}V^{(1)} + diag(l^{(2)})^{-1} e^{S^{(2)}-m} V^{(2)} = O

如果不考虑减去最大值,flash attention的计算过程如下所示。

4 flash attention2 

flash attention在计算的过程中,每计算一个块,需要使用最新m更新l。

参考softmax公式,l主要是对一化。

flash attention2采用了归一化后移的处理方法,即在计算过程中暂时不做归一化,将其放到最后。

计算过程示例如下。

\displaystyle m^{(1)} = rowmax(S^{(1)}) \in \mathbb{R}^{B_r} \\ l^{(1)} = rowsum(e^{S^{(1)}-m^{(1)}}) \in \mathbb{R}^{B_r} \\ \hat{P} ^{(1)} = diag(l^{(1)})^{-1} e^{S^{(1)}-m^{(1)}} \in \mathbb{R}^{B_r \times B_c} \\ \hat{O}^{(1)} = e^{S^{(1)}-m^{(1)}}V^{(1)} \in \mathbb{R}^{B_r \times d} \\ m^{(2)} = max(m^{(1)}, rowmax(S^{(2)})) = m \\ l^{(2)} = e^{m^{(1)} - m^{(2)}} l^{(1)} + rowsum(e^{S^{(2)}-m^{(2)}}) =l \\ \hat{P} ^{(2)} = diag(l^{(2)})^{-1} e^{S^{(2)}-m^{(2)}} \\ \hat{O}^{(2)} = diag(e^{m^{(1)}-m^{(2)}})\hat{O}^{(1)} + \hat{P} ^{(2)} V^{(2)} \\= e^{S^{(1)}-m}V^{(1)} + e^{S^{(2)}-m} V^{(2)} \\O^{(2)} = diag(l^{(2)})^{-1} \hat{O}^{(2)} = O

这种处理方式,有效减少了如此可降低计算复杂程度,减少了中间数据P1的存储需求。

以下是flash attention2的算法示例

附录

1 diag的含义

diag(a,b,c…)表示一个对角矩阵,即除主对角线外的元素都为0,diag有如下特性

diag(l)^{-1} = diag(l^{-1})

示例程序如下

import numpy as np
a = np.arange(1, 4)
diag_a = np.diag(a)
diag_a_reverse = 1 / np.diag(a)
print(f"diag_a: {diag_a}\ndiag_a_reverse: {diag_a_reverse}")

输入如下

diag_a: [[1 0 0]
 [0 2 0]
 [0 0 3]]
diag_a_reverse: [[1.                inf        inf]
 [       inf 0.5               inf]
 [       inf        inf 0.33333333]]

2 论文1/diag(l1/l2)问题

https://arxiv.org/pdf/2307.08691

flash attention2论文如下所示的公式中1/diag(l1/l2)可能不对,应该为diag(l1/l2),不需要取反。

另外,这一步只考虑了分母项的适配,没有考虑分子项的适配,应该还需要乘exp(m1-m2)项。

由于大部分网络资料沿用论文所示公式,所以还需要继续求证。

reference

---

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

https://arxiv.org/pdf/2307.08691

flash attention计算过程的探索和学习

https://blog.csdn.net/liliang199/article/details/151789333

FlashAttention2详解(性能比FlashAttention提升200%)

https://zhuanlan.zhihu.com/p/645376942

http://www.dtcms.com/a/399530.html

相关文章:

  • 领域驱动设计的vo、do、dto
  • 画图软件在线纵横seo
  • 求解四阶泛函 u‘‘‘‘ - u‘‘ + f = 0 的驻点及周期性边界条件
  • VC维(Vapnik-Chervonenkis Dimension)的故事:模型复杂度的衡量
  • FM收音机RDS功能深度解析
  • 做网站运营还是翻译郑州市房产信息网官方网站
  • SM2商用密码算法轻量化技术:原理、实践与未来展望
  • 双目视觉的传统立体匹配算法有哪些?
  • 电子商务网站版面布局更改wordpress链接
  • Day28_【深度学习(7)—卷积神经网络CNN】
  • 手机百度网盘登录入口织梦做的网站好优化
  • Al驱动下的智能网联汽车创新与应用专题培训
  • 【Stream API学习】
  • 怎样下载建设银行信用卡网站蓝色科技企业网站模板免费下载
  • ubuntu16安装python3.12
  • 编辑网站教程阜宁县城乡建设局新的官方网站
  • 禅城区做网站策划企业公示信息填报
  • LSTM:长短期记忆网络的原理、演进与应用
  • OpenHarmony 4.0 Release横屏配置
  • 网站开发前端与后端铁汉生态建设有限公司网站
  • 服务器安全基线配置
  • 随机森林算法详解:从原理到实战
  • 数据库回表查询解析:从原理到实战优化
  • 详解单元测试、集成测试、系统测试
  • 企业网站设计要点郑州seo哪家公司最强
  • 互动网站制作wordpress add option
  • wordpress 上传 重命名郑州seo外包平台
  • 【C++实战㊱】解锁C++依赖倒置:从理论到实战的蜕变之旅
  • 项目案例作业2:对案例进行面向对象分析
  • 锤子助手插件功能七十二:对话内图片「一键添加至表情」