当前位置: 首页 > news >正文

手搓多模态-10 旋转位置编码的原理和实现

前情回顾

几个以来一直忙于工作暂时没有时间维护博客文章最近得以忙里偷闲思前想后决定还是时间之前没能写完博客好好完善一下博主最近一些强化学习工作后面也会博客更新强化学习原理

手搓多模态这个专栏我们之前已经完成模型构建工作但是我们提供旋转位置编码实现细节于是本文继续延续之前内容完成旋转位置编码原理实现

旋转位置编码介绍

旋转位置编码相对位置编码一员主要通过旋转方式不同位置的Token嵌入位置标记直接介绍旋转位置编码之前我们首先引入经典位置编码分析从而避免直接讲解旋转位置编码带来突兀

经典位置编码

上面经典位置编码数学计算公式其中 pos 表示Token序列绝对位置2i或者2i+1表示针对某个Token嵌入向量2i(2i+1)表示Token编码嵌入维度

可能诧异为什么位置编码长这样呢?或者为什么这样位置编码行之有效

理解这一点我们需要了解位置编码解决什么问题

Transformer序列首先Embedding转化Token Embedding,每个Token一般会被转换为一个固定维度(如4096维)向量这一过程并行发生然后Token Embedding经历编码解码最后生成一个预测Embedding logits(概率分布)

这些过程目前都是并行发生意味即使输入序列Token随意变换顺序它们参与计算生成概率分布不会发生改变显然我们直觉完全不符合(比如“你与我打招呼”和“我与你打招呼”所表达的语义信息是不同的)因为Token之间目前没有显式编码任何位置相关信息进去

为了Token之间计算具备一些位置信息就需要第一次注意力计算之前位置EmbeddingToken嵌入向量一个比较直观想到方案Token嵌入文章构建一个Token向量维度的且包含位置信息的向量叠加上去可能会说为什么不在Token没被编码时候文章为什么叠加不是相乘 因为这里只是一种方案而已你可以尝试采用方案只是作者采取这种简单直观方案而已

假定某个位置posEmbedding为  , 位置编码向量为

其中f是位置编码函数,故编码位置信息Embedding为

后续Embedding参与注意力计算主要QKV矩阵乘积那么必然会引入, 这里因为位置编码向量所以这里实际在做向量内积。那么假定位置编码函数可以写成以下形式

假定嵌入的维度为2,那么我们对不同Token取内积来看看,设对前pos_1与pos_2的Embedding求内积的结果为w:

这里其实已经可以看出一些眉目因为我们其实希望这里可以反应两个Token相对位置关系这种位置关系假定可以h建模w应当可以写成如下形式

则有:

这种形式公式很容易让人想到三角函数中的和差公式

因此研究人员想到位置编码三角函数编码此时只需要位置编码向量不同维度信息置为关于pos与m三角函数即可。最简单想法直接三角函数角度posm乘积如下所示

这样没法建模我们刚刚形式带入内积计算即可知道

这样得到的映射并非是 的建模,而是 的建模,归根结底因为和差公式要想建模 则在位置编码它们不同导致所以我们必须系数抹除得到以下这样位置编码函数:

前面我们内积考虑嵌入向量维度2情况实际嵌入向量会有512甚至更高维度有人可能直接扩展高维然后重复交替两个就行

这样显而易见很容易扩展维度伴随出现一个问题就是三角函数周期不妨试想,在维度为2的时候,模型通过QKV乘积一些变换得到以下结果

这里实则得到的是,这里的T三角函数周期那么这样其实没办法两个token相对位置信息因为它们直接增加若干三角函数周期都是可以满足这个等式(这里需要考虑计算机处理浮点数的精度问题)所以理论上这样没法精确相对位置信息传递模型

于是研究人员想到扩展高维pos前面加上一个与维度相关的系数假定a,可以如下形式

这样最后模型能够通过每组三角函数(即位置编码向量中相邻的sin函数和cos函数)解读信息这样

这里可以列出这样的线性方程(在位置编码的维度数足够高的情况下)

这里面只有 和k1,k2...不知道,其他信息都是知道的,a是系数(设计的原则是a在不同维度的值是不一样的),T是三角函数周期,y是由三角函数的值反推出的角度,理论上只要维度足够辅助模型求出 ,这里虽然满秩但是因为潜在k是整数条件可以帮助模型判断当然严格数学博主没办法给出这里已经足够理解位置编码那么我们看看这个系数a这么定义

这里可以看到分母随着维度增高增大系数a减小这样导致了位置编码中高维三角函数的系数周期频率维度三角函数系数周期频率位置编码图像我们可以看到这一点

高维(右边的列)变换频率低,(左边的列)变换频率所以将a的分母的底数设置10000为了防止不同pos位置编码恰好相同假如pos_2正好导致Token_2所有维度位置编码的三角函数Token_1位置编码差了整数周期它们位置编码可能完全一样所以为了防止这一点应当使得最高周期足够大,这样必须非常非常可能使得这两个Token最高位置编码相同而当a的底数设置为10000时,最高维的周期为,实际使用时候pos整数Token位置编码碰撞可能非常

旋转位置编码

看到这里相信你已经经典位置编码由来有了一个初步理解后面我们理解旋转位置编码不会

旋转位置编码动机

前面经典位置编码我们通过最开始时候在Token嵌入向量直接添加位置编码公式表示如下所示

这样实际引入噪声,一方面它改变了嵌入的模长,另一方面在QKV计算虽然引入了,但是同样也引入了以及 这样的噪声

其次由于只在最开始时候加入位置编码这就使得位置信号靠后注意力计算非常微弱。而旋转位置编码就是为了解决两个问题提出

针对噪声优化

正如名称所言旋转位置编码想到通过旋转方式不是叠加方式嵌入位置信息这样嵌入向量模长不变改变嵌入向量方向

二维旋转公式如下

其中m一个系数θ角度表示向量q旋转mθ

旋转天生就能反应这种角度之间差值原先经典位置编码针对pos位置编码我们这样设置

这里我们m类比pos, a类比θ(之所以这么理解是因为在位置编码中,自变量是pos,我们是要建模pos与pos之间的关系,当然你可以把pos比作θ,a比作m, 这样反应的其实是同一个Token的位置编码时不同维度的位置编码的变化),很容易想到在经典位置编码通过计算元角度apos得到三角函数编码位置编码,旋转矩阵应该相同维度嵌入进行旋转旋转角度也为apos注意这里的a是每两个值共用同一个a

比如嵌入只有时候编码位置信息可以采用以下公式得到

那么扩展高维就是每一维的数值进行旋转旋转pos * a

那么这样我们得到如下旋转矩阵

为了计算进行优化作者旋转计算写成这种形式

这就是原文公式由来

针对信号减弱优化

既然最开始位置编码导致信息传播减弱问题那就每一次注意力时候做一次旋转即可这样就能保留旋转信息

位置编码实现

我们之前的代码我们通过以下代码应用旋转位置编码

cos,sin = self.rotary_embedding(value_states,position_ids,seq_len = None) #不改变形状,这里value_states只是为了提供device
query_states,key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

我们尚未实现它们这里cossin主要就是构造公式cos向量sin向量,理解原理之后我们实现很简单

首先实现cos,sin向量构建

class GemmaRotaryEmbedding(nn.Module): def __init__(self, dim:int, max_position_embeddings:int = 2048, base:int = 10000, device:torch.device = None):super().__init__()
		self.dim = dim
		self.base = base
		self.max_position_embeddings = max_position_embeddings# 计算公式 theta_i = base**(-(2i / dim)) ,i = 0,1,2,3.....dim // 2
		inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))		self.register_buffer("inv_freq", inv_freq, persistent=False)def forward(self,x , position_ids:torch.Tensor, seq_len:Optional[int] = None):# x shape: [batch_size, num_attention_heads,seq_len, head_dim]		self.inv_freq = self.inv_freq.to(x.device) ##转移设备## 为后面的矩阵乘法做准备 inv_freq:[dim // 2] -> inv_freq_expanded: [batch_size, dim // 2, 1]
		inv_freq_expanded = self.inv_freq[None,:,None].float().expand(position_ids.shape[0],-1,1)## 为后面的矩阵乘法做准备 position_ids:[batch_size, seq_len] -> position_ids_expanded: [batch_size, 1, seq_len]
		position_ids_expanded = position_ids[:, None, :].float()		device_type = x.device.type
		device_type = device_type if isinstance(device_type,str) and device_type != "mps" else "cpu"with torch.autocast(device_type=device_type,enabled=False): ## 禁用自动混合精度## 这里是为了给每个不同位置的token准备一个m * theta, m代表position_id,由position_ids_expanded提供,theta_i 由inv_freq_expanded提供## [batch_size, dim // 2, 1] * [batch_size, 1, seq_len] = [batch_size, dim // 2, seq_len] -> [batch_size, seq_len, dim // 2] 
			freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1,2)## 将 dim // 2的 m * theta_i 扩展到dim维度,这里没有重复交替(m*θ 1,m*θ 1,m*θ 2,m*θ 2,...)是参照hugging face的实现,也能达到同样的效果
			emb = torch.cat((freqs,freqs),dim=-1)			cos = emb.cos()
			sin = emb.sin()return cos.to(dtype=x.dtype),sin.to(dtype=x.dtype)

这里可能比较疑惑这里旋转位置编码为什么论文不同这是因为hugging face导入模型的时候权重发生置换具体原因可以参考 旋转位置编码差异详解 这里不过多赘述,大家只要知道原理就行。

那么至此我们完成了主模型的最后一块拼图。希望本文对各位理解位置编码有一些帮助。


文章转载自:

http://v5sjA2fH.skdhm.cn
http://A8bz9mzC.skdhm.cn
http://8bHy7Gy7.skdhm.cn
http://MxhVVBYJ.skdhm.cn
http://BEcDuiH4.skdhm.cn
http://qU6PLDz7.skdhm.cn
http://AhUncQOo.skdhm.cn
http://WyTNayR7.skdhm.cn
http://Ev4DJqca.skdhm.cn
http://aG4ifLaL.skdhm.cn
http://G1wojecq.skdhm.cn
http://7bVt59z3.skdhm.cn
http://dsbAL3A6.skdhm.cn
http://PG9xHE5H.skdhm.cn
http://qfohRFGE.skdhm.cn
http://vRhLgXfZ.skdhm.cn
http://zX6hoZOK.skdhm.cn
http://DREZGMST.skdhm.cn
http://JG2ZPSXW.skdhm.cn
http://AIOSFz3m.skdhm.cn
http://bFi1U5kv.skdhm.cn
http://xK68jKMT.skdhm.cn
http://VBK5OjOF.skdhm.cn
http://UWneUElQ.skdhm.cn
http://BkX8gW7N.skdhm.cn
http://n7RXTLbQ.skdhm.cn
http://cbV2xn08.skdhm.cn
http://Ylamronz.skdhm.cn
http://W2nUe8PA.skdhm.cn
http://ZAlSt31x.skdhm.cn
http://www.dtcms.com/a/383574.html

相关文章:

  • C# --- dispose机制与using关键字
  • HakcMyVM-Aurora
  • Flask学习笔记(一)
  • MobaXterm软件访问ZYNQ板卡的Linux系统
  • 基于vLLM与YOLO的智能图像分类系统
  • 标准CAN帧介绍
  • 蚂蚁矿机S19 Pro 104T技术参数解析及性能分析
  • 一小时解决RabbitMQ面试题
  • HBM4量产就绪|2026年AI与数据中心新标配
  • 细粒度图像分类的可解释性Finer-CAM
  • C++中多线程core的问题分析和总结
  • scrapy框架-day02
  • 电商导购平台的移动端架构设计:React Native在多端统一中的实践
  • class_9:java 抽象类和接口
  • [硬件电路-209]:电子携带两种能量,一种是电流宏观运动的动能,一种是绕着原子核运动的原子轨道能量;前者是电势能与热能转化的媒介;后者是实现光能与电能的转化
  • HBase启动报错“Master is initializing”解决方案
  • 交换机的级联和堆叠
  • QT加密和哈希
  • 历史数据分析——中科曙光
  • Dropout:深度学习中的随机丢弃正则化技术
  • 数组存储 · 行主序与列主序 | 应用 / 基地址 / 选择策略
  • 贪心算法应用:最早截止时间优先(EDF)问题详解
  • 每天五分钟深度学习:神经网络的权重参数如何初始化
  • BisenetV1/2网络以及模型推理转换
  • Codeforces Round 1050 (Div. 4)补题
  • 【Java后端】Spring Boot 多模块项目实战:从零搭建父工程与子模块
  • c++命名空间详解
  • 第15课:知识图谱与语义理解
  • HarmonyOS图形处理:Canvas绘制与动画开发实战
  • ffmpeg 有什么用处?