当前位置: 首页 > news >正文

Transformer中核心机制的讲解:位置编码,注意力机制优化!

前文回顾

在上章节中,我们阐述了近年来针对transformer中激活函数和归一化机制的一些优化点。在本文中,我们将阐述关于位置编码和注意力机制的一些优化。

位置编码

位置编码是希望transformer在注意力计算时能知道token中蕴含的位置信息。

传统的位置编码是transformer论文中提出的正弦-余弦位置编码,基于以下公式:

来为输入向量添加位置编码。

近几年针对位置编码有了不同的实现方式:如自学习的位置编码嵌入,旋转位置编码,ALiBi等等。

位置编码可以按编码的特性分为:绝对位置编码和相对位置编码。

接下来依次讲解不同的位置编码:

自学习的位置编码

自学习的位置编码是指首先设置一个位置编码的上限:比如10000,表示一个输入输出过程中最多出现10000个token。

然后随机生成10000组固定维度的向量作为10000个token位置的编码。然后在训练中不停的优化它们。

  • 潜在问题:由于语料中的短句占比远高于特长句,因此position id很高的位置编码难以被训练。其次这种位置编码的扩展性不好。
  • 解决办法:训练时加长上下文窗口或者迁移为旋转位置编码等办法。

旋转位置编码(RoPE)

介绍旋转位置编码之前,请参考我之前的博客提到的关于经典位置编码的一些细节:

Transformer中的核心问题 知识点汇总

在经典位置编码中,针对某个pos的位置编码,是一系列sin,cos交错的三角函数,这些三角函数是用pos乘了一个元角度k计算得到的,其实k就是,所以位置编码的每个维度位置i有着不同的频率/周期,从而可以辅助模型定位到相对位置的精确解。

所以经典位置编码是通过pos直接计算的三角函数,即角度值的三角函数本身,再添加到原本的嵌入向量中,但这一步首先改变了原有的嵌入本身,其次会在后续计算中引入一些噪声信息。举个例子:

A 为嵌入A,其位置编码为a
B 为嵌入B,其位置编码为b
(A + a) * (B + b) = (A * B + a * B + A * b + a * b)
其中,只有a * b才包含相对位置信息,a * B + A * b属于噪声信息

所以旋转位置编码做的不是计算角度值再加到向量中。

而是计算角度值之后,将嵌入旋转对应的角度,从而不改变嵌入中的模长信息。

而旋转操作通过旋转矩阵来完成,对于一个二维的旋转矩阵:

其中m表示位置id,θ表示元角度,这样即可将[q0,q1]向量旋转。

所以给定一个token的位置,其位置编码方式是将每两个数值组成的向量进行一个角度的旋转,而这个所谓的角度计算和经典位置编码的角度选取是一样的,也是用这个元角度 * pos位置。这样就组成了一个大的旋转矩阵。

因为这是一个稀疏矩阵,所以在计算上是可以优化的:

作者在计算上将其优化成了如下的等效实现方式:

这就是现在主流模型在使用的旋转位置编码。

参考资料:图解RoPE旋转位置编码

旋转位置编码的旋转过程在attention计算前进行,这样就能保证只在attention计算时改变嵌入,从而不为其他过程的计算带来噪声。

  • 旋转位置编码相较于经典位置编码,有效改良了外推能力不足的弱点

什么是位置编码的外推能力?

即模型训练时的上下文窗口长度往往很低,此时的位置编码看起来很有效果,但实际推演时的上下文窗口长度可能远大于训练时的窗口长度,此时位置编码方式在高位token编码的稳定性和有效性即为位置编码的外推能力。

注意力线性偏置(ALiBi)

虽然RoPE的外推能力已经比经典位置编码的外推能力更好了,但是RoPE仍然是基于三角函数的位置编码,外推能力仍然不足。因为在高位的高维旋转中,周期太长了,导致长度低的语料的位置编码旋转几乎不会对高维的值产生什么变化,但训练集中长度低的语料占比又很高,这就导致模型难以在高位的高维旋转中提取出相对位置信息。

于是人们提出了ALiBi,这是一种很简单很直观的位置编码方式。

ALiBi直接修改每个头的注意力得分矩阵,并直接添加了包含相对位置信息的掩码矩阵,除此之外ALiBi为每个注意力头的位置掩码矩阵准备了不同的m做为系数。

m的计算源于:

这里可以看到,不同的注意力头的位置编码掩码的系数是不一样的,越靠前的头的系数越大,相对距离越远的token惩罚越高,故专注于局部,越靠后的头的系数越小,可以将远程的信息传递下来。

m的计算代码如下:

def get_slopes(n):
    def slopes_power_of_2(p):
        # ① 先为一个 2 的幂 p 生成等比数列斜率
        start = 2 ** (-8.0 / p)          # 这一行可以写成 2**(-2**-(log2(p)-3))
        return [start ** (i + 1) for i in range(p)]    if log2(n).is_integer():            # n 本身就是 2^k —— 走简单分支
        return slopes_power_of_2(n)    # ② n 不是 2 的幂:
    closest = 2 ** floor(log2(n))       # “向下取最近的 2 的幂”(≤ n)
    #   先为这 closest 个头生成一串斜率
    slopes = slopes_power_of_2(closest)    # ③ 还差 n-closest 个头怎么办?
    #   递归生成 2*closest 个斜率,再隔一个取一次
    #   选前 n-closest 个补到列表里
    extra = get_slopes(2 * closest)[0::2][: n - closest]
    return slopes + extra

可以看到,cloest之后的斜率是重新生成的,重新调用了slopes_power_of_2(),所以这里导斜率不是单调的,但是没有影响,作者通过实验验证了这一点。

注意力机制

传统的注意力计算就是朴素的q,k,v之间的计算:

后续过程中,人们尝试从各种角度对注意力计算过程进行改良:

包括multi-query attention,flash attention等。

Multi-query attention/Group attention

传统的多头注意力机制将Q,K,V分成多个头部并行计算,Grouped-query attention提出将不同的query组共享同一个key/value组,从而减少显存带宽,加速计算。

而Multi-query attention直接更加极端,将所有的query组全都共享同一个key/value组。

但这样也会造成精度的丢失。

Flash Attention

Flash Attention是对Attention计算的分块实现,由于不需要保存中间注意力得分结果,所以大幅加速了注意力的计算。

首先需要介绍背景知识:

GPU的存储分为SRAM,HBM和DRAM,类似与CPU的三级缓存,容量越高,读取速度越低,容量越低,读取速度越高。

传统的注意力计算在GPU上进行的过程往往是将需要的数据从HBM中输入到SRAM中进行计算,然后再把结果写出到HBM,这样来来回回读写造就了attention计算的慢速,因为attention计算大部分操作都是内存等待型,即内存传输数据的速度跟不上芯片计算的速度。

下面用草稿画了个图来描述这一过程:

这里可以观察到,传统的方法计算存在大量的IO时间,导致计算效率低。

Flash attention将QKV相互作用得到的结果O直接一次性在SRAM中完成之后写出到HBM中,节省了许多中间过程的IO,从而提高时间效率。

具体而言,我们可以通过以下草图来了解注意力的分块计算过程:

这里,如果没有softmax,那么这将是一个很简单的分块过程,只需要分别滑动Q,K,V窗口,选择出不同的行和列进行计算就能得到最后的O。

关键是softmax需要等一整行的注意力得分都计算出来了才能做归一化,那么这样初期计算的O中的值就是不正确的。

Flash Attention想到用增量的方法计算softmax。

首先需要了解softmax的计算过程:

这里可以参考:softmax数值稳定性

softmax的增量计算如下图所示:

原论文的伪代码描述如下:

最后diag单位矩阵即是对原先O的输出进行l_1相乘,然后通过新的l逆矩阵来作为原公式中的分母。

相关文章:

  • 郑州网站建设 郑州网站设计深圳最新通告今天
  • 毕业设计论文网seo服务建议
  • 中山小榄网站百度网址安全中心
  • 镇江手机网站制作做网络优化的公司排名
  • 做pc端网站精英扫描图片找原图
  • 福州网站备案网络营销策略的概念
  • 【Python报错】成功解决error: subprocess-exited-with-error:安装lxml模块不再报错
  • 中宇厨卫启动年中品质回馈活动,深化用户体验
  • 京东正式开源 Taro on HarmonyOS C-API 版本,为鸿蒙应用跨端开发提供高性能框架
  • 阿里云Web应用防火墙3.0使用CNAME接入传统负载均衡CLB
  • 阿里云Redhat系Linux修改ssh默认端口
  • 网络安全就业方向与现实发展分析:机遇、挑战与未来趋势
  • 微信小程序 / UNIAPP --- 阻止小程序返回(顶部导航栏返回、左 / 右滑手势、安卓物理返回键和调用 navigateBack 接口)
  • Android14音频子系统 - 系统框架概述
  • 前端路由的基石:深度剖析 Hash 与 History 模式的本质差异与实战抉择
  • Spring:多数据源配置多个事务管理器DEMO
  • 【SpringBoot】⭐️AutoConfiguration配置的前世今生
  • c语言中的浮点类型
  • 细谈QT信号与槽机制
  • spring中的@Cacheable缓存
  • php后台增加权限控制
  • Odoo API 集成:XML-RPC 与 JSON-RPC 的比较
  • RabbitMq中启用NIO
  • 操作系统学习笔记 | 第一章 计算机系统概述
  • “ICU”归来的小鹏,如何抗衡小米YU7?
  • EJB知识