模拟配置类
class MockConfig:
    """Minimal stand-in for a model config, holding only the fields RoPE reads.

    Every argument is optional; calling ``MockConfig()`` yields a 512-wide,
    8-head configuration with a 2048-token context and no RoPE scaling.

    Args:
        max_position_embeddings: maximum sequence length.
        rope_theta: base of the rotary frequency schedule.
        hidden_size: model width.
        num_attention_heads: number of attention heads.
        rope_scaling: ``None``, or a dict describing a RoPE variant.
    """

    def __init__(
        self,
        max_position_embeddings=2048,
        rope_theta=10000.0,
        hidden_size=512,
        num_attention_heads=8,
        rope_scaling=None,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        # Per-head width is derived, never passed in, so it cannot disagree
        # with hidden_size / num_attention_heads.
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.rope_scaling = rope_scaling
Qwen3MoeRotaryEmbedding模块
import torch
import torch.nn as nn


def default_rope_init(config, device=None):
    """Compute the inverse frequencies for plain (unscaled) RoPE.

    Args:
        config: object exposing ``rope_theta`` and either a non-``None``
            ``head_dim`` or both ``hidden_size`` and ``num_attention_heads``.
        device: optional device the returned tensor is moved to.

    Returns:
        A ``(inv_freq, attention_scaling)`` pair: a float32 tensor with one
        entry per even index below the rotary dimension, and the constant
        scaling factor ``1.0``.
    """
    # The rotary dimension is per head. When the config carries no usable
    # head_dim, derive it from the model width instead of rotating over the
    # whole hidden_size.
    dim = getattr(config, "head_dim", None)
    if dim is None:
        dim = config.hidden_size // config.num_attention_heads
    inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    print("inv_freq:", inv_freq.shape)
    return inv_freq.to(device), 1.0


# Maps a rope_type name to the function that builds its inverse frequencies.
ROPE_INIT_FUNCTIONS = {
    "default": default_rope_init,
}


class Qwen3MoeRotaryEmbedding(nn.Module):
    """Produce the cos/sin tables used to rotate queries and keys.

    The RoPE variant is picked from ``config.rope_scaling`` when that is a
    dict (key ``rope_type``, falling back to ``type``); otherwise the
    ``"default"`` variant is used. An unregistered variant raises ``KeyError``.
    """

    inv_freq: torch.Tensor

    # The annotation is a string so that defining this class does not require
    # MockConfig to be importable; any object with the same fields works.
    def __init__(self, config: "MockConfig", device=None):
        super().__init__()
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer follows the module across devices but is
        # left out of the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for the given positions.

        Args:
            x: tensor consulted only for its device and dtype.
            position_ids: 2-D tensor of positions, one row per batch entry.

        Returns:
            Two tensors of shape ``(batch, seq, 2 * len(inv_freq))`` cast to
            ``x.dtype``.
        """
        # (batch, n_freq, 1) @ (batch, 1, seq) -> (batch, n_freq, seq)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # Autocast is disabled so the angles are computed in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Each frequency appears twice so that channel i and channel
            # i + n_freq share one rotation angle.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        print("cos:", cos)
        print("sin:", sin)
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
示例
# Build the mock configuration and the rotary-embedding module under test.
config = MockConfig()
rope = Qwen3MoeRotaryEmbedding(config)

# Toy problem size.
batch_size, seq_length = 2, 8
num_heads, head_dim = config.num_attention_heads, config.head_dim

# Random queries and keys, laid out as (batch, seq, heads, head_dim).
# q is drawn before k.
q, k = (torch.randn(batch_size, seq_length, num_heads, head_dim) for _ in range(2))

# Every batch row uses the positions 0 .. seq_length - 1.
position_ids = torch.arange(seq_length)[None, :].expand(batch_size, -1)

cos, sin = rope(q, position_ids)

print(f"\nRoPE输出:")
print(f" - cos: {cos.shape}")
print(f" - sin: {sin.shape}")
inv_freq: torch.Size([32])
cos: tensor([[[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],[ 0.5403, 0.7318, 0.8460, ..., 1.0000, 1.0000, 1.0000],[-0.4161, 0.0709, 0.4315, ..., 1.0000, 1.0000, 1.0000],...,[ 0.2837, -0.8209, -0.9461, ..., 1.0000, 1.0000, 1.0000],[ 0.9602, -0.2114, -0.9731, ..., 1.0000, 1.0000, 1.0000],[ 0.7539, 0.5114, -0.7004, ..., 1.0000, 1.0000, 1.0000]],[[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],[ 0.5403, 0.7318, 0.8460, ..., 1.0000, 1.0000, 1.0000],[-0.4161, 0.0709, 0.4315, ..., 1.0000, 1.0000, 1.0000],...,[ 0.2837, -0.8209, -0.9461, ..., 1.0000, 1.0000, 1.0000],[ 0.9602, -0.2114, -0.9731, ..., 1.0000, 1.0000, 1.0000],[ 0.7539, 0.5114, -0.7004, ..., 1.0000, 1.0000, 1.0000]]])
sin: tensor([[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 0.0000e+00],[ 8.4147e-01, 6.8156e-01, 5.3317e-01, ..., 2.3714e-04,1.7783e-04, 1.3335e-04],[ 9.0930e-01, 9.9748e-01, 9.0213e-01, ..., 4.7427e-04,3.5566e-04, 2.6670e-04],...,[-9.5892e-01, -5.7113e-01, 3.2394e-01, ..., 1.1857e-03,8.8914e-04, 6.6676e-04],[-2.7942e-01, -9.7740e-01, -2.3037e-01, ..., 1.4228e-03,1.0670e-03, 8.0011e-04],[ 6.5699e-01, -8.5931e-01, -7.1372e-01, ..., 1.6600e-03,1.2448e-03, 9.3346e-04]],[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 0.0000e+00],[ 8.4147e-01, 6.8156e-01, 5.3317e-01, ..., 2.3714e-04,1.7783e-04, 1.3335e-04],[ 9.0930e-01, 9.9748e-01, 9.0213e-01, ..., 4.7427e-04,3.5566e-04, 2.6670e-04],...,[-9.5892e-01, -5.7113e-01, 3.2394e-01, ..., 1.1857e-03,8.8914e-04, 6.6676e-04],[-2.7942e-01, -9.7740e-01, -2.3037e-01, ..., 1.4228e-03,1.0670e-03, 8.0011e-04],[ 6.5699e-01, -8.5931e-01, -7.1372e-01, ..., 1.6600e-03,1.2448e-03, 9.3346e-04]]])RoPE输出:- cos: torch.Size([2, 8, 64])- sin: torch.Size([2, 8, 64])
应用RoPE到查询和键
def rotate_half(x):
    """Swap the two halves of the last axis, negating the half moved to the front.

    The split point is ``x.shape[-1] // 2``; the result is
    ``concat(-second_part, first_part)`` along the last axis.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate ``q`` and ``k`` by the angles encoded in ``cos`` and ``sin``.

    Args:
        q: query tensor.
        k: key tensor.
        cos: cosine table; must broadcast against ``q`` and ``k`` once a
            singleton axis is inserted at ``unsqueeze_dim``.
        sin: sine table, same shape as ``cos``.
        position_ids: accepted but not used.
        unsqueeze_dim: axis at which the singleton dimension is inserted into
            ``cos`` and ``sin`` before broadcasting.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Standard RoPE combination: t * cos + rotate_half(t) * sin.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
# q and k are laid out as (batch, seq, heads, head_dim) while cos/sin are
# (batch, seq, head_dim), so the singleton axis must be inserted at dim 2
# (the head axis). With the default unsqueeze_dim=1 the position axis of
# cos/sin would be broadcast against the head axis of q/k, which only runs at
# all because seq_length == num_heads here, and rotates by head index instead
# of by position.
q_rotated, k_rotated = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)

print(f"\n应用RoPE后:")
print(f" - 旋转后的查询 (q_rotated): {q_rotated.shape}")
print(f" - 旋转后的键 (k_rotated): {k_rotated.shape}")

print(f"\n=== RoPE性质验证 ===")
assert q_rotated.shape == q.shape, "查询张量形状不一致"
assert k_rotated.shape == k.shape, "键张量形状不一致"
print("✓ 查询和键张量形状保持一致")

print(f"\n=== 不同位置的RoPE值示例 ===")
print("位置0的cos值前5维:", cos[0, 0, :5].tolist())
print("位置0的sin值前5维:", sin[0, 0, :5].tolist())
print("位置3的cos值前5维:", cos[0, 3, :5].tolist())
print("位置3的sin值前5维:", sin[0, 3, :5].tolist())

print(f"\n=== 正交性验证 ===")
# Head 0 at position 0 versus head 0 at position 1. The two vectors are
# rotated by different angles, so their inner product is generally NOT
# preserved. What the rotation does preserve is each vector's norm and the
# inner product of vectors that share a position.
original_inner_prod = torch.sum(q[0, 0, 0, :] * q[0, 1, 0, :])
rotated_inner_prod = torch.sum(q_rotated[0, 0, 0, :] * q_rotated[0, 1, 0, :])
print(f"位置0和1的原始内积: {original_inner_prod:.6f}")
print(f"位置0和1的旋转后内积: {rotated_inner_prod:.6f}")
print(f"差异: {abs(original_inner_prod - rotated_inner_prod):.6f}")

print(f"\n=== 示例完成 ===")
print("RoPE模块成功处理了查询和键张量,保持了它们的形状并应用了旋转位置编码")
应用RoPE后:
 - 旋转后的查询 (q_rotated): torch.Size([2, 8, 8, 64])
 - 旋转后的键 (k_rotated): torch.Size([2, 8, 8, 64])

=== RoPE性质验证 ===
✓ 查询和键张量形状保持一致

=== 不同位置的RoPE值示例 ===
位置0的cos值前5维: [1.0, 1.0, 1.0, 1.0, 1.0]
位置0的sin值前5维: [0.0, 0.0, 0.0, 0.0, 0.0]
位置3的cos值前5维: [-0.9899924993515015, -0.6279267072677612, -0.11596616357564926, 0.3009673058986664, 0.5827536582946777]
位置3的sin值前5维: [0.14112000167369843, 0.7782725095748901, 0.9932531714439392, 0.9536344408988953, 0.8126488924026489]

=== 正交性验证 ===
位置0和1的原始内积: -1.464770
位置0和1的旋转后内积: -1.464770
差异: 0.000000

=== 示例完成 ===
RoPE模块成功处理了查询和键张量,保持了它们的形状并应用了旋转位置编码