当前位置: 首页 > news >正文

llama源码学习·model.py[3]ROPE旋转位置编码(4)ROPE的应用

一、源码注释

def apply_rotary_emb(
    xq: torch.Tensor, # 查询矩阵
    xk: torch.Tensor, # 键矩阵
    freqs_cis: torch.Tensor, # 旋转嵌入
) -> Tuple[torch.Tensor, torch.Tensor]:
    
    # 首先将xq和xk张量转换为浮点数
    # 然后使用reshape将最后一个维度拆分为两个维度,每个维度都有大小为2,这样做是为了为复数张量提供实部和虚部。
    # 然后,torch.view_as_complex用于从实部和虚部创建复数张量
    
    # *xq.shape[:-1] 是保留原始形状的所有维度,除了最后一个维度。
    # -1 是一个占位符,它告诉PyTorch自动计算这个维度,以保持元素总数不变。
    # 2 是最后一个维度,这是为了为接下来的复数转换做准备。每个复数由两个浮点数表示(实部和虚部),所以最后一个维度是2。
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    
    # 将freqs_cis重新reshape以匹配xq_的形状,以便进行广播运算。
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    
    # 这两行代码将查询和键张量与旋转嵌入相乘,应用位置嵌入。
    # 函数计算xq_和xk_与freqs_cis的元素乘积(这是一个复数乘法),
    # 在复数乘法中,xq_和xk_的实部和虚部会分别与freqs_cis的实部和虚部进行乘法运算。
    # flatten(3) 将两个最后的维度合并回一个维度
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    
    # 函数返回经过旋转嵌入处理的查询和键张量,同时确保它们的数据类型与原始输入相匹配。
    return xq_out.type_as(xq), xk_out.type_as(xk)

二、举例说明

# query矩阵
xq = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  
# key矩阵
xk = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
# 频率张量
freqs_cis = torch.tensor([[1.0000+0.0000j], [1.0000+0.0000j]])  

*** xq.shape: *** torch.Size([2, 2, 2])

*** xk.shape: *** torch.Size([2, 2, 2])

freqs_cis.shape: torch.Size([2, 1])

# 首先,apply_rotary_emb函数会将query和key矩阵reshape并转换为复数张量。
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

xq.float().reshape(*xq.shape[:-1], *-*1, 2).shape: torch.Size([2, 2, 1, 2])

xk.float().reshape(*xk.shape[:-1], *-*1, 2).shape: torch.Size([2, 2, 1, 2])

xq_.shape: torch.Size([2, 2, 1])

xk_.shape: torch.Size([2, 2, 1])

# freqs_cis 的形状是 (2, 1),xq_ 的形状是(2, 2, 1), 所以我们需要将freqs_cis形状调整为 (1, 2, 2, 1)
freqs_cis_new = reshape_for_broadcast(freqs_cis, xq_)

freqs_cis_new.shape : freqs_cis_new.shape

# 函数会将输入复数张量与频率张量相乘。
xq_out_complex = xq_ * freqs_cis_new
xk_out_complex = xk_ * freqs_cis_new

xq_out_complex: tensor([[[1.+2.j], [3.+4.j]], [[5.+6.j], [7.+8.j]]])

# 将结果重塑并转换回实数张量。
xq_out = torch.view_as_real(xq_out_complex).flatten(3)
xk_out = torch.view_as_real(xk_out_complex).flatten(3)

xq_out: tensor([[[[ 1., 2.], [ 6., 8.]], [[15., 18.], [28., 32.]]]])

相关文章:

  • Python八字排盘系统实现分析
  • flutter报错:Could not find com.meituan.android.walle:plugin
  • centos7.9 脚本一键升级到openssl-3.4.0,openssh-9.9p1
  • JSON 解析中需要清理的危险字符
  • 解析Collections工具类主要功能
  • css实现报警特效
  • 计算机技术系列博客——目录页(持续更新)
  • UVM stop_sequences详细介绍与举例(含代码示例与注意事项)
  • 【初探数据结构】树与二叉树
  • Java 反射机制
  • 织梦DedeCMS如何获得在列表和文章页获得顶级或上级栏目名称
  • Filter Solutions学习-02 【高级设计】界面介绍
  • AI图像理解技术的演进
  • AI日报 - 2025年3月21日
  • PyTorch深度学习框架60天进阶学习计划-第27天:模型量化原理(一)
  • Web-Machine-N7靶机通关攻略
  • Web-Machine-N7靶机:渗透测试与漏洞挖掘的实战利器
  • 【从古生物代谢到硅基计算:解码技术加速的深层密码
  • Spring Boot中定时任务Cron表达式的终极指南
  • 广东启动“跨境电商+产业带”系列活动 三年打造30个产业振兴样板
  • 大学2025丨专访西湖大学副校长邓力:如何才能培养“不惧未知”的创新者
  • 民间打拐志愿者上官正义遭人身安全威胁,杭州公安:已立案
  • 国际金价下跌,中概股多数上涨,穆迪下调美国主权信用评级
  • 大外交丨3天拿下数万亿美元投资,特朗普在中东做经济“加法”和政治“减法”
  • 南宁一学校发生伤害案件,警方通报:嫌疑人死亡,2人受伤
  • 坚决打好产业生态培育攻坚战!陈吉宁调研奉贤区