当前位置: 首页 > news >正文

Qwen3 中旋转位置编码

模拟配置类


class MockConfig:
    """Minimal stand-in for a model config, holding only the RoPE-related fields.

    Calling ``MockConfig()`` with no arguments yields the values used in this
    walkthrough (2048 positions, theta 10000.0, hidden size 512, 8 heads).

    Args:
        max_position_embeddings: Stored as ``max_position_embeddings``.
        rope_theta: Stored as ``rope_theta`` (the RoPE frequency base).
        hidden_size: Stored as ``hidden_size``.
        num_attention_heads: Stored as ``num_attention_heads``.
        rope_scaling: Stored as ``rope_scaling``; ``None`` by default.
    """

    def __init__(
        self,
        max_position_embeddings=2048,
        rope_theta=10000.0,
        hidden_size=512,
        num_attention_heads=8,
        rope_scaling=None,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        # Per-head width; with the defaults this is 512 // 8 = 64.
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.rope_scaling = rope_scaling

Qwen3MoeRotaryEmbedding模块

import torch
import torch.nn as nn


def default_rope_init(config, device=None):
    """Compute the default (unscaled) RoPE inverse frequencies.

    Args:
        config: Object exposing ``rope_theta`` and either a non-``None``
            ``head_dim`` or both ``hidden_size`` and ``num_attention_heads``.
        device: Optional device the returned tensor is moved to.

    Returns:
        A tuple ``(inv_freq, attention_scaling)``. ``inv_freq`` is a float32
        tensor with one entry per pair of rotary dimensions (``dim // 2`` for
        even ``dim``) and ``attention_scaling`` is always ``1.0``.
    """
    # RoPE is applied per attention head, so the rotary width is the head
    # width. When the config does not carry ``head_dim`` it has to be derived
    # from the hidden size; using ``hidden_size`` itself would be too wide.
    dim = getattr(config, "head_dim", None)
    if dim is None:
        dim = config.hidden_size // config.num_attention_heads

    # inv_freq[i] = theta ** (-2i / dim) for i = 0 .. dim/2 - 1
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    inv_freq = 1.0 / (config.rope_theta ** exponents)
    print("inv_freq:", inv_freq.shape)
    return inv_freq.to(device), 1.0


# Maps a rope type name to the function that builds its inverse frequencies.
ROPE_INIT_FUNCTIONS = {
    "default": default_rope_init,
}


class Qwen3MoeRotaryEmbedding(nn.Module):
    """Produces the cos/sin tables that rotary position embedding multiplies in.

    The inverse frequencies are computed once in ``__init__`` by the function
    registered in ``ROPE_INIT_FUNCTIONS`` for the configured rope type and are
    kept as a non-persistent buffer.
    """

    inv_freq: torch.Tensor  # declared so linters know about the registered buffer

    def __init__(self, config: "MockConfig", device=None):
        """Select the rope type from ``config`` and build ``inv_freq``.

        Args:
            config: Config object; reads ``rope_scaling`` (if present),
                ``max_position_embeddings`` and whatever the selected init
                function reads.
            device: Passed through to the init function.

        Raises:
            KeyError: If the rope type is not a key of ``ROPE_INIT_FUNCTIONS``.
        """
        super().__init__()
        # Backward compatibility: the "rope_type" key was originally "type".
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # persistent=False keeps the buffer out of the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for the given positions.

        Args:
            x: Tensor used only for its device and dtype.
            position_ids: 2-D tensor of positions, indexed as (batch, seq).

        Returns:
            Two tensors of shape (batch, seq, 2 * len(inv_freq)) cast to
            ``x.dtype``; both are multiplied by ``attention_scaling``.
        """
        batch = position_ids.shape[0]
        # (D/2,) -> (1, D/2, 1) -> (B, D/2, 1)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        # (B, S) -> (B, 1, S)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # force float32
            # (B, D/2, 1) @ (B, 1, S) -> (B, D/2, S) -> (B, S, D/2)
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate so both halves of the head dim share the same angles: (B, S, D)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        print("cos:", cos)
        print("sin:", sin)
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

示例

# Build the mock config and the rotary embedding module.
config = MockConfig()
rope = Qwen3MoeRotaryEmbedding(config)

batch_size = 2
seq_length = 8
num_heads = config.num_attention_heads  # 8
head_dim = config.head_dim  # 64

# Random queries and keys laid out as (batch, seq, heads, head_dim).
q = torch.randn(batch_size, seq_length, num_heads, head_dim)  # (2, 8, 8, 64)
k = torch.randn(batch_size, seq_length, num_heads, head_dim)  # (2, 8, 8, 64)

# Positions 0 .. seq_length - 1, repeated per batch row: (8,) -> (1, 8) -> (2, 8)
position_ids = torch.arange(seq_length).unsqueeze(0).expand(batch_size, -1)

# Only the device and dtype of the first argument are used by the module.
cos, sin = rope(q, position_ids)

print("\nRoPE输出:")
print(f"  - cos: {cos.shape}")
print(f"  - sin: {sin.shape}")
inv_freq: torch.Size([32])
cos: tensor([[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],[ 0.5403,  0.7318,  0.8460,  ...,  1.0000,  1.0000,  1.0000],[-0.4161,  0.0709,  0.4315,  ...,  1.0000,  1.0000,  1.0000],...,[ 0.2837, -0.8209, -0.9461,  ...,  1.0000,  1.0000,  1.0000],[ 0.9602, -0.2114, -0.9731,  ...,  1.0000,  1.0000,  1.0000],[ 0.7539,  0.5114, -0.7004,  ...,  1.0000,  1.0000,  1.0000]],[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],[ 0.5403,  0.7318,  0.8460,  ...,  1.0000,  1.0000,  1.0000],[-0.4161,  0.0709,  0.4315,  ...,  1.0000,  1.0000,  1.0000],...,[ 0.2837, -0.8209, -0.9461,  ...,  1.0000,  1.0000,  1.0000],[ 0.9602, -0.2114, -0.9731,  ...,  1.0000,  1.0000,  1.0000],[ 0.7539,  0.5114, -0.7004,  ...,  1.0000,  1.0000,  1.0000]]])
sin: tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,0.0000e+00,  0.0000e+00],[ 8.4147e-01,  6.8156e-01,  5.3317e-01,  ...,  2.3714e-04,1.7783e-04,  1.3335e-04],[ 9.0930e-01,  9.9748e-01,  9.0213e-01,  ...,  4.7427e-04,3.5566e-04,  2.6670e-04],...,[-9.5892e-01, -5.7113e-01,  3.2394e-01,  ...,  1.1857e-03,8.8914e-04,  6.6676e-04],[-2.7942e-01, -9.7740e-01, -2.3037e-01,  ...,  1.4228e-03,1.0670e-03,  8.0011e-04],[ 6.5699e-01, -8.5931e-01, -7.1372e-01,  ...,  1.6600e-03,1.2448e-03,  9.3346e-04]],[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,0.0000e+00,  0.0000e+00],[ 8.4147e-01,  6.8156e-01,  5.3317e-01,  ...,  2.3714e-04,1.7783e-04,  1.3335e-04],[ 9.0930e-01,  9.9748e-01,  9.0213e-01,  ...,  4.7427e-04,3.5566e-04,  2.6670e-04],...,[-9.5892e-01, -5.7113e-01,  3.2394e-01,  ...,  1.1857e-03,8.8914e-04,  6.6676e-04],[-2.7942e-01, -9.7740e-01, -2.3037e-01,  ...,  1.4228e-03,1.0670e-03,  8.0011e-04],[ 6.5699e-01, -8.5931e-01, -7.1372e-01,  ...,  1.6600e-03,1.2448e-03,  9.3346e-04]]])RoPE输出:- cos: torch.Size([2, 8, 64])- sin: torch.Size([2, 8, 64])

应用RoPE到查询和键

def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last dimension at its midpoint into ``x1`` (first half) and
    ``x2`` (second half) and returns ``concat(-x2, x1)`` along that dimension.

    Args:
        x: Tensor whose last dimension is split in two.

    Returns:
        A tensor with the same shape as ``x``.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to the query and key tensors.

    ``cos`` and ``sin`` get a singleton axis inserted at ``unsqueeze_dim`` so
    that they broadcast against ``q`` and ``k``. The axis must be inserted
    where ``q``/``k`` have their head axis:

    * ``q``/``k`` shaped (batch, heads, seq, head_dim): ``unsqueeze_dim=1``.
    * ``q``/``k`` shaped (batch, seq, heads, head_dim): ``unsqueeze_dim=2``.

    Picking the wrong axis is not detected when the sequence length happens
    to equal the number of heads; the positions are then silently applied
    along the head axis.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table, e.g. shaped (batch, seq, head_dim).
        sin: Sine table, same shape as ``cos``.
        position_ids: Unused by this implementation.
        unsqueeze_dim: Axis at which ``cos``/``sin`` are unsqueezed.

    Returns:
        A tuple ``(q_embed, k_embed)`` of the rotated query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# q and k are laid out as (batch, seq, heads, head_dim) while cos/sin are
# (batch, seq, head_dim), so the singleton axis must go at dim 2 (the head
# axis). With the default unsqueeze_dim=1 the tables would broadcast along
# the head axis instead, which goes unnoticed here only because
# seq_length == num_heads.
q_rotated, k_rotated = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)

print("\n应用RoPE后:")
print(f"  - 旋转后的查询 (q_rotated): {q_rotated.shape}")
print(f"  - 旋转后的键 (k_rotated): {k_rotated.shape}")

# 6. Verify RoPE properties
print("\n=== RoPE性质验证 ===")
# The rotation must not change tensor shapes.
assert q_rotated.shape == q.shape, "查询张量形状不一致"
assert k_rotated.shape == k.shape, "键张量形状不一致"
print("✓ 查询和键张量形状保持一致")

# 7. Show the RoPE values at different positions
print("\n=== 不同位置的RoPE值示例 ===")
print("位置0的cos值前5维:", cos[0, 0, :5].tolist())
print("位置0的sin值前5维:", sin[0, 0, :5].tolist())
print("位置3的cos值前5维:", cos[0, 3, :5].tolist())
print("位置3的sin值前5维:", sin[0, 3, :5].tolist())

# 8. Orthogonality check. A rotation preserves the inner product of two
# vectors rotated by the SAME angle, i.e. q and k taken at the same position.
# (Vectors at different positions are rotated by different angles, so their
# inner product is not preserved in general.)
print("\n=== 正交性验证 ===")
original_inner_prod = torch.sum(q[0, 1, 0, :] * k[0, 1, 0, :])
rotated_inner_prod = torch.sum(q_rotated[0, 1, 0, :] * k_rotated[0, 1, 0, :])
print(f"位置1的q·k原始内积: {original_inner_prod:.6f}")
print(f"位置1的q·k旋转后内积: {rotated_inner_prod:.6f}")
print(f"差异: {abs(original_inner_prod - rotated_inner_prod):.6f}")

# For the same reason every individual vector keeps its norm.
assert torch.allclose(q_rotated.norm(dim=-1), q.norm(dim=-1), atol=1e-4), "查询向量模长发生变化"
assert torch.allclose(k_rotated.norm(dim=-1), k.norm(dim=-1), atol=1e-4), "键向量模长发生变化"
print("✓ 旋转保持每个向量的模长")

# 9. Relative-position property: put the SAME q/k vector at every position;
# the q·k score then depends only on the offset between the two positions.
print("\n=== 相对位置性质验证 ===")
q_same = q[:, :1].expand_as(q)
k_same = k[:, :1].expand_as(k)
q_same_rot, k_same_rot = apply_rotary_pos_emb(q_same, k_same, cos, sin, unsqueeze_dim=2)
score_1_0 = torch.sum(q_same_rot[0, 1, 0, :] * k_same_rot[0, 0, 0, :])
score_4_3 = torch.sum(q_same_rot[0, 4, 0, :] * k_same_rot[0, 3, 0, :])
print(f"位置(1,0)的注意力分数: {score_1_0:.6f}")
print(f"位置(4,3)的注意力分数: {score_4_3:.6f}")
print(f"差异: {abs(score_1_0 - score_4_3):.6f}")
assert torch.allclose(score_1_0, score_4_3, atol=1e-3), "注意力分数应只依赖相对位置"

print("\n=== 示例完成 ===")
print("RoPE模块成功处理了查询和键张量,保持了它们的形状并应用了旋转位置编码")
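应用RoPE后:
  - 旋转后的查询 (q_rotated): torch.Size([2, 8, 8, 64])
  - 旋转后的键 (k_rotated): torch.Size([2, 8, 8, 64])

=== RoPE性质验证 ===
✓ 查询和键张量形状保持一致

=== 不同位置的RoPE值示例 ===
位置0的cos值前5维: [1.0, 1.0, 1.0, 1.0, 1.0]
位置0的sin值前5维: [0.0, 0.0, 0.0, 0.0, 0.0]
位置3的cos值前5维: [-0.9899924993515015, -0.6279267072677612, -0.11596616357564926, 0.3009673058986664, 0.5827536582946777]
位置3的sin值前5维: [0.14112000167369843, 0.7782725095748901, 0.9932531714439392, 0.9536344408988953, 0.8126488924026489]

=== 正交性验证 ===
位置0和1的原始内积: -1.464770
位置0和1的旋转后内积: -1.464770
差异: 0.000000

=== 示例完成 ===
RoPE模块成功处理了查询和键张量,保持了它们的形状并应用了旋转位置编码

注:以上输出是在 apply_rotary_pos_emb 使用默认 unsqueeze_dim=1 的情况下得到的。示例中 q、k 的布局是 (batch, seq, heads, head_dim),此时 cos/sin 会沿 head 维广播(只因 seq_length 与 num_heads 恰好都是 8 才没有报错),第 0 个 head 始终使用位置 0 的角度(cos=1、sin=0),所以该 head 的向量根本没有被旋转,"差异: 0.000000" 并不能说明 RoPE 保持不同位置向量之间的内积。对这种布局应传入 unsqueeze_dim=2;RoPE 真正保证的是向量模长不变,以及 q·k 分数只依赖于两个位置的相对偏移。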

文章转载自:

http://Xt93AYGk.sskns.cn
http://YPddEBWP.sskns.cn
http://7ZUBS0GL.sskns.cn
http://uwOYE5zW.sskns.cn
http://PVI5ZQsi.sskns.cn
http://KtOLxw6C.sskns.cn
http://StzxOFOd.sskns.cn
http://ojoQ3Zod.sskns.cn
http://H01HaRPE.sskns.cn
http://bg4Nn513.sskns.cn
http://5YoexZqW.sskns.cn
http://XZAFPhML.sskns.cn
http://4MLIfNIC.sskns.cn
http://aKvUwvoX.sskns.cn
http://ODNFVxsJ.sskns.cn
http://WentVA5I.sskns.cn
http://FozGgfLm.sskns.cn
http://bRb19t2O.sskns.cn
http://RBmAS4s0.sskns.cn
http://G2ka1yOV.sskns.cn
http://JFheUrvh.sskns.cn
http://b6b8FBFd.sskns.cn
http://Pgj1iIQY.sskns.cn
http://T69BNxqI.sskns.cn
http://s4t5jZeT.sskns.cn
http://XmYAQ0ab.sskns.cn
http://cjiVg7D3.sskns.cn
http://1KpypkB1.sskns.cn
http://FxjkQ29a.sskns.cn
http://bslPhIY1.sskns.cn
http://www.dtcms.com/a/379262.html

相关文章:

  • vue3项目sass全局变量的设置和使用
  • 透彻理解Python环境管理:虚拟环境、Conda、Pyenv和Pipx为何而生
  • 【unity实战】实现在unity3D模型上画线写字涂鸦效果
  • 2025最新超详细FreeRTOS入门教程:第十三章 FreeRTOS临界区与原子操作
  • 玩转Docker | 使用Docker部署dufs文件管理工具
  • 计算机组成原理:定点乘法运算
  • PyQt5 主窗口状态栏实时显示当前路径的实现与分析
  • 利用conda打包/复刻生信环境
  • glide介绍
  • vscode 中通义灵码显示登录过期
  • 【VScode】ssh报错
  • STM32 norflash W25Q64移植FatFS
  • 【Git】版本控制-Gitee
  • Qt常见问题
  • 泛函Φ(u)驻点的方程与边界条件 / 求给定泛函驻点满足的方程及边界条件
  • 统一权限管理平台登录不了怎么办?
  • 中级统计师-统计法规-第四章 统计管理体制
  • java反射(详细教程)
  • 【Leetcode】高频SQL基础题--1327.查找拥有有效邮箱的用户
  • Redis(集群)
  • 吾爱小工具!一键屏蔽流氓软件!
  • 告别网络监控“盲区”!OpManager全新升级解锁轻量监控新纪元!
  • 实验室试管架 | 塑料、金属等多种材质与规格 | 支持多种试管尺寸 | Sigma-Aldrich
  • .net 类库生成的DLL源码混淆加密
  • 北京-测试-入职金融公司第四周-加班&未发现bug
  • Story2Board: A Training-Free Approach for Expressive Storyboard Generation论文
  • 纯`css`轻松防止滚动穿透
  • 30天Java速成计划:从零基础到能刷算法题!
  • 【点云分类】简述对pointnet和pointnet++的理解
  • 【202509新版】Hexo + GitHub Pages 免费部署个人博客|保姆级教程