当前位置: 首页 > news >正文

旋转位置编码(3)

目录

    • 代码
    • 代码解释
    • 函数 1: `precompute_pos_cis`
      • 代码解析
    • 函数 2: `apply_rotary_emb`
      • 1. `unite_shape(pos_cis, x)`:调整 `pos_cis` 形状
      • 2. 复数化 `xq` 和 `xk`
      • 3. 应用旋转变换
      • 4. 返回结果
    • 总结
      • 1. `precompute_pos_cis(dim, end, theta)`
      • 2. `apply_rotary_emb(xq, xk, pos_cis)`
    • 示例
    • 广播机制的作用
      • 广播规则
        • 广播细节
      • 如果形状不兼容
      • 总结

代码

def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

代码解释

这段代码实现了 旋转位置编码(Rotary Position Embedding, RoPE),用于在 Transformer 自注意力机制中引入相对位置信息,提升模型的捕捉能力。下面详细拆解代码的核心逻辑。


函数 1: precompute_pos_cis

作用:预计算 旋转位置编码的相位因子,即 复数形式的角度信息

def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): 
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 
    t = torch.arange(end, device=freqs.device)  # type: ignore 
    freqs = torch.outer(t, freqs).float()  # type: ignore 
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 
    return pos_cis 

代码解析

  1. 计算角频率 (freqs)

    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    
    • dim 是输入的特征维度(比如 dim=16)。
    • torch.arange(0, dim, 2)[: (dim // 2)] 生成 [0, 2, 4, ..., dim-2] 的索引,然后取前 dim//2 个。
    • 计算 freqs = 1.0 / (theta ** (index / dim)),即 指数衰减的频率(类似于 Sinusoidal PE)。
  2. 构造位置 t

    t = torch.arange(end, device=freqs.device)
    
    • t 是一个长度为 end(默认 32K)的整数序列 [0, 1, 2, ..., end-1],表示序列中的位置索引。
  3. 计算 tfreqs 的外积

    freqs = torch.outer(t, freqs).float()
    
    • torch.outer(t, freqs) 计算 时间 t 与角频率 freqs 的外积,得到一个矩阵:
      在这里插入图片描述

      这个矩阵表示每个位置 t_i 对应的角度。

  4. 转换为极坐标的复数表示

    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    
    • torch.polar(r, θ) 创建复数:
      在这里插入图片描述

    • pos_cis复数编码的旋转因子,后续会与查询 (xq) 和键 (xk) 相乘,实现旋转变换。


函数 2: apply_rotary_emb

作用:将预计算的旋转嵌入 pos_cis 应用于 xq(查询)和 xk(键),完成 旋转位置编码的实际应用

def apply_rotary_emb(xq, xk, pos_cis): 
    def unite_shape(pos_cis, x): 
        ndim = x.ndim 
        assert 0 <= 1 < ndim 
        assert pos_cis.shape == (x.shape[1], x.shape[-1]) 
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 
        return pos_cis.view(*shape) 

1. unite_shape(pos_cis, x):调整 pos_cis 形状

作用:将 pos_cis 形状变换,以匹配 xq 的形状,方便广播计算。

  • xq.shape = (B, L, D)(批量、序列长度、特征维度)
  • pos_cis.shape = (L, D)
  • 目标形状:(1, L, D) 适配 xq
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
  • 只在 序列维 (L) 和最后一维 (D) 维持原尺寸,其余地方设置为 1,保证广播机制正确执行。

2. 复数化 xqxk

xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
  • 输入 xq, xk 形状: (B, L, D)
  • reshape(*xq.shape[:-1], -1, 2) 变成 (B, L, D/2, 2),把 D 维度拆成 两个部分,对应于 复数的实部和虚部
  • torch.view_as_complex() 将最后一维 (2,) 转换为 复数格式 (B, L, D/2)

3. 应用旋转变换

pos_cis = unite_shape(pos_cis, xq_)
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) 
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) 
  • xq_ * pos_cis 执行复数乘法,相当于对 xq_ 旋转 pos_cis 指定的角度。
  • torch.view_as_real() 将复数转换回实数 (B, L, D/2, 2)
  • .flatten(3) 重新恢复 (B, L, D) 形状。

4. 返回结果

return xq_out.type_as(xq), xk_out.type_as(xk)
  • 确保返回数据类型和 xq, xk 一致。

总结

1. precompute_pos_cis(dim, end, theta)

  • 生成旋转角度信息,以 复数 形式存储位置编码。
  • 公式
    [
    e^{i\theta} = \cos(\theta) + i\sin(\theta)
    ]

2. apply_rotary_emb(xq, xk, pos_cis)

  • 将旋转编码应用到查询 xq 和键 xk
  • 具体步骤:
    1. 调整 pos_cis 形状 适配 xq
    2. 转换 xq, xk 为复数(实部 + 虚部)。
    3. 执行复数乘法(旋转操作)。
    4. 转换回实数,恢复 (B, L, D) 形状。

最终,这个过程增强了注意力机制,使其具有相对位置编码能力,提升了 Transformer 在长序列建模中的表现。


示例

import torch

# 设置超参数
dim = 16  # 特征维度
seq_len = 10  # 序列长度
batch_size = 2  # 批量大小
head_dim = dim // 2  # 假设 head_dim = dim / 2

# 预计算旋转嵌入
pos_cis = precompute_pos_cis(dim, end=seq_len)

# 生成模拟的 xq 和 xk
xq = torch.randn(batch_size, seq_len, dim)  # (B, L, D)
xk = torch.randn(batch_size, seq_len, dim)  # (B, L, D)

# 打印原始维度
print(f"原始 xq 维度: {xq.shape}")
print(f"原始 xk 维度: {xk.shape}")
print(f"预计算 pos_cis 维度: {pos_cis.shape}")


原始 xq 维度: torch.Size([2, 10, 16])
原始 xk 维度: torch.Size([2, 10, 16])
预计算 pos_cis 维度: torch.Size([10, 8])
# 应用旋转嵌入
xq_out, xk_out = apply_rotary_emb(xq, xk, pos_cis)

# 打印转换后的维度
print(f"转换后 xq 维度: {xq_out.shape}")
print(f"转换后 xk 维度: {xk_out.shape}")

转换后 xq 维度: torch.Size([2, 10, 8, 2])
转换后 xk 维度: torch.Size([2, 10, 8, 2])

广播机制的作用

xq_ = torch.randn(2, 10,16) 
pos_cis= torch.randn(1, 10,16) 
print((xq_ * pos_cis).shape)
torch.Size([2, 10, 16])

广播规则

当进行 xq_ * pos_cis 这样的逐元素乘法时,PyTorch 会根据广播规则(broadcasting rules)来扩展 pos_cis 使其匹配 xq_ 的形状。

广播细节
  1. xq_ 形状是 (2, 10, 16)
  2. pos_cis 形状是 (1, 10, 16)

根据广播规则:

  • 1 可以扩展成 2(沿 batch 维度),变成 (2, 10, 16)
  • 1016 维度匹配,不需要改变

所以 pos_cis 自动扩展(2, 10, 16),然后进行逐元素相乘,最终结果形状仍然是 (2, 10, 16)

如果形状不兼容

如果 pos_cis(10, 16),它的 shape 实际上是 (10, 16) → (1, 10, 16),仍然可以广播。
但是如果 pos_cis(10, 1), 它会变成 (1, 10, 1),那么广播到 (2, 10, 16)1 会扩展到 16,仍然能进行逐元素运算。
如果 pos_cis(3, 10, 16),和 xq_2 不匹配,就会报错。

总结

pos_cis 形状 (1, 10, 16) 通过广播扩展到 (2, 10, 16),因此 xq_ * pos_cis 的结果仍然是 (2, 10, 16)

相关文章:

  • HarmonyOS
  • Spring Boot 项目中使用责任链模式实现复杂接口解耦和动态编排(带示例)
  • 前端技术百宝箱
  • Tweak Power:全方位电脑系统优化的高效工具
  • MySQL 与 MongoDB 的区别
  • CAN总线协议攻防实战:从漏洞分析到攻击模拟
  • 衣联网的商品列表页面结构是怎样的?
  • 设计基于锁的并发数据结构_第六章_《C++并发编程实战》笔记
  • 新一代开源数字供应链安全审查与治理平台:悬镜源鉴SCA
  • 版本控制泄露源码 .svn
  • 机器学习数学基础:45.多重响应分析
  • 鸿蒙应用开发-轻松获取http网络请求
  • 【从零开始学习计算机科学】操作系统(七)文件管理
  • Vue3 Pinia 符合直觉的Vue.js状态管理库
  • Trae AI 辅助修复uniapp 微信小程序的Bug
  • DeepSeek 助力 Vue3 开发:打造丝滑的表格(Table)之添加列宽调整功能,示例Table14_02带边框和斑马纹的固定表头表格
  • Linux第0节:Linux环境的搭建
  • ES C++客户端安装及使用
  • vue3如何配置环境和打包
  • el-table中slot=“header“和#header的区别
  • 女排奥运冠军宋妮娜:青少年保持身心健康才能走得更远
  • 大外交丨3天拿下数万亿美元投资,特朗普在中东做经济“加法”和政治“减法”
  • 坚决打好产业生态培育攻坚战!陈吉宁调研奉贤区
  • 车建兴被留置:跌落的前常州首富和红星系重整迷路
  • 人民网三评“网络烂梗”:莫让低级趣味围猎青少年
  • 陕西宁强县委书记李宽任汉中市副市长