
Attention Mechanism Implementation in Qwen3

Importing the required libraries

import torch
import torch.nn as nn
import math
from typing import Optional, Tuple, Callable
from dataclasses import dataclass
import typing
from transformers.utils import TransformersKwargs
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
Unpack = typing.Unpack  # typing.Unpack requires Python 3.11+

Rotary position embedding (RoPE) helper functions

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
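A quick sanity check, not part of the original code: because the cos/sin values come from real angles, apply_rotary_pos_emb is a pure rotation of each (x1, x2) coordinate pair, so the norm of every head vector is preserved.

# Illustrative check (my addition): RoPE rotates pairs of channels, so per-head norms are unchanged.
q = torch.randn(1, 2, 4, 8)        # [batch, num_heads, seq, head_dim]
k = torch.randn(1, 2, 4, 8)
angles = torch.rand(1, 4, 4)       # [batch, seq, head_dim/2], arbitrary angles
emb = torch.cat((angles, angles), dim=-1)
cos, sin = emb.cos(), emb.sin()    # [1, 4, 8] each, the same layout the rotary embedding below produces
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)
assert torch.allclose(k.norm(dim=-1), k_rot.norm(dim=-1), atol=1e-4)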

Key/value repetition function

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:  # [2, 4, 8, 64], n_rep = 2
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
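To make the grouped-query expansion concrete, here is an illustrative snippet (my addition, not from the original post): with 4 KV heads and n_rep=2, each KV head is duplicated so that it serves two consecutive query heads.

# Illustrative: KV heads are duplicated along the head axis for grouped-query attention.
kv = torch.randn(2, 4, 8, 64)      # [batch, num_key_value_heads, seq, head_dim]
expanded = repeat_kv(kv, n_rep=2)  # -> [2, 8, 8, 64]
print(expanded.shape)              # torch.Size([2, 8, 8, 64])
# Expanded heads 0 and 1 are both copies of original KV head 0.
assert torch.equal(expanded[:, 0], kv[:, 0])
assert torch.equal(expanded[:, 1], kv[:, 0])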

Eager attention forward function

def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,                      # [2, 8, 8, 64]
    key: torch.Tensor,                        # [2, 4, 8, 64]
    value: torch.Tensor,                      # [2, 4, 8, 64]
    attention_mask: Optional[torch.Tensor],   # [2, 1, 8, 8]
    scaling: float,                           # 0.125
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)      # [2, 8, 8, 64]
    value_states = repeat_kv(value, module.num_key_value_groups)  # [2, 8, 8, 64]
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling  # [2, 8, 8, 8]
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]  # [2, 1, 8, 8]
        print("causal_mask:", causal_mask.shape)
        attn_weights = attn_weights + causal_mask  # [2, 8, 8, 8]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    print("attn_weights:", attn_weights.shape)  # [2, 8, 8, 8]
    attn_output = torch.matmul(attn_weights, value_states)  # [2, 8, 8, 8] @ [2, 8, 8, 64] -> [2, 8, 8, 64]
    attn_output = attn_output.transpose(1, 2).contiguous()  # -> [2, 8, 8, 64] (batch, seq, heads, head_dim)
    return attn_output, attn_weights
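Since this eager path is just masked scaled dot-product attention over the expanded KV heads, its output can be cross-checked against torch.nn.functional.scaled_dot_product_attention. A minimal sketch (my addition; the SimpleNamespace stand-in only carries the two attributes the function reads):

import types
import torch.nn.functional as F

fake_module = types.SimpleNamespace(num_key_value_groups=2, training=False)
q = torch.randn(2, 8, 8, 64)   # [batch, num_heads, seq, head_dim]
k = torch.randn(2, 4, 8, 64)   # [batch, num_kv_heads, seq, head_dim]
v = torch.randn(2, 4, 8, 64)
mask = torch.tril(torch.ones(2, 1, 8, 8))
mask = (1.0 - mask) * torch.finfo(torch.float32).min

out, _ = eager_attention_forward(fake_module, q, k, v, mask, scaling=64 ** -0.5)
ref = F.scaled_dot_product_attention(q, repeat_kv(k, 2), repeat_kv(v, 2), attn_mask=mask)
assert torch.allclose(out, ref.transpose(1, 2), atol=1e-4)  # out is [batch, seq, heads, head_dim]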

RoPE position encoding implementation

def default_rope_init(config, device=None):
    """Default RoPE initialization function."""
    dim = config.head_dim if hasattr(config, 'head_dim') else config.hidden_size // config.num_attention_heads  # 64
    inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # 1 / (10000.0 ** (torch.arange(0, 64, 2) / 64)) -> 32 frequencies
    print("inv_freq:", inv_freq.shape)
    return inv_freq.to(device), 1.0  # inv_freq, attention_scaling


# Registry of RoPE initialization functions
ROPE_INIT_FUNCTIONS = {
    "default": default_rope_init,
}


class Qwen3MoeRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
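One property worth verifying, shown here as an illustrative addition (a types.SimpleNamespace stands in for the config so the snippet is self-contained): when the same query/key content is rotated at different positions, the q·k score depends only on the relative offset between those positions.

import types

cfg = types.SimpleNamespace(head_dim=64, rope_theta=10000.0, max_position_embeddings=2048)
rope = Qwen3MoeRotaryEmbedding(cfg)
x = torch.zeros(1, 16, 512)             # forward only uses x's dtype and device
pos = torch.arange(16).unsqueeze(0)     # [1, 16]
cos, sin = rope(x, pos)                 # [1, 16, 64] each

q = torch.randn(1, 1, 1, 64).expand(1, 1, 16, 64)  # identical query content at every position
k = torch.randn(1, 1, 1, 64).expand(1, 1, 16, 64)  # identical key content at every position
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
scores = torch.einsum("bhqd,bhkd->bhqk", q_rot, k_rot)[0, 0]  # [16, 16]
# The same offset (q at 5 vs k at 3, q at 10 vs k at 8) gives the same score.
assert torch.allclose(scores[5, 3], scores[10, 8], atol=1e-4)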

Qwen3Moe attention mechanism implementation

class Qwen3MoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx  # 0
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)  # 64
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads  # 8 // 4 = 2
        self.scaling = self.head_dim**-0.5  # 0.125
        self.attention_dropout = config.attention_dropout  # 0.0
        self.is_causal = True
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)   # 512 -> 8*64
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)   # 512 -> 4*64
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)   # 512 -> 4*64
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)   # 8*64 -> 512
        self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.sliding_window = getattr(config, "sliding_window", None)

    def forward(
        self,
        hidden_states: torch.Tensor,                             # [2, 8, 512]
        position_embeddings: tuple[torch.Tensor, torch.Tensor],  # ([2, 8, 64], [2, 8, 64])
        attention_mask: Optional[torch.Tensor],                  # [2, 1, 8, 8]
        past_key_values: Optional = None,                        # optional KV cache (transformers Cache-like)
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]            # [2, 8]
        hidden_shape = (*input_shape, -1, self.head_dim)  # [2, 8, -1, 64]
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        # [2, 8, 512] -> [2, 8, 512] -> [2, 8, 8, 64] -> [2, 8, 8, 64]
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        # [2, 8, 512] -> [2, 8, 256] -> [2, 8, 4, 64] -> [2, 4, 8, 64]
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        # [2, 8, 512] -> [2, 8, 256] -> [2, 8, 4, 64] -> [2, 4, 8, 64]
        cos, sin = position_embeddings  # [2, 8, 64], [2, 8, 64]
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        print("query_states:", query_states.shape)  # [2, 8, 8, 64]
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            # ALL_ATTENTION_FUNCTIONS is the registry provided by transformers; only the eager path is exercised here.
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )  # attn_output: [2, 8, 8, 64], attn_weights: [2, 8, 8, 8]
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()  # [2, 8, 512]
        attn_output = self.o_proj(attn_output)  # [2, 8, 512]
        return attn_output, attn_weights  # [2, 8, 512], [2, 8, 8, 8]

Mock config and RMSNorm implementation

@dataclass
class MockConfig:
    hidden_size: int = 512
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    head_dim: int = 64
    max_position_embeddings: int = 2048
    rope_theta: float = 10000.0
    rms_norm_eps: float = 1e-6
    attention_bias: bool = False
    attention_dropout: float = 0.0
    _attn_implementation: str = "eager"


class Qwen3MoeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
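Two quick checks on the pieces defined so far (illustrative additions, not from the original post). First, the grouped-query layout shows up directly in the projection shapes: with 8 query heads but only 4 KV heads, k_proj and v_proj are half the size of q_proj. Second, with its default unit weight, Qwen3MoeRMSNorm yields outputs whose root-mean-square over the last dimension is approximately 1.

cfg = MockConfig()
attn = Qwen3MoeAttention(cfg, layer_idx=0)
print(attn.q_proj.weight.shape)   # torch.Size([512, 512])  -> 8 heads * 64
print(attn.k_proj.weight.shape)   # torch.Size([256, 512])  -> 4 heads * 64
print(attn.v_proj.weight.shape)   # torch.Size([256, 512])

norm = Qwen3MoeRMSNorm(64)
y = norm(torch.randn(2, 8, 64))
print(y.pow(2).mean(-1).sqrt())   # every entry is close to 1.0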

Main function: test code

if __name__ == "__main__":
    # Configuration
    config = MockConfig()
    attention_layer = Qwen3MoeAttention(config, layer_idx=0)
    rotary_emb = Qwen3MoeRotaryEmbedding(config)

    batch_size = 2
    seq_length = 8
    hidden_size = config.hidden_size  # 512
    hidden_states = torch.randn(batch_size, seq_length, hidden_size)  # [2, 8, 512]
    position_ids = torch.arange(seq_length).unsqueeze(0).expand(batch_size, -1)  # [2, 8]

    cos, sin = rotary_emb(hidden_states, position_ids)
    print(f"Position embeddings:")
    print(f"  - cos shape: {cos.shape}")  # [2, 8, 64]
    print(f"  - sin shape: {sin.shape}")  # [2, 8, 64]

    # Additive causal mask: 0 where attention is allowed, a large negative value elsewhere
    attention_mask = torch.tril(torch.ones(batch_size, 1, seq_length, seq_length))  # [2, 1, 8, 8]
    attention_mask = (1.0 - attention_mask) * torch.finfo(torch.float32).min  # [2, 1, 8, 8]

    attention_output, attention_weights = attention_layer(
        hidden_states=hidden_states,
        position_embeddings=(cos, sin),
        attention_mask=attention_mask,  # the mask is a required argument of forward
    )
    print(f"\nAttention results:")
    print(f"  - Input shape: {hidden_states.shape}")      # [2, 8, 512]
    print(f"  - Output shape: {attention_output.shape}")  # [2, 8, 512]
    print(f"  - Attention weights shape: {attention_weights.shape}")  # [2, 8, 8, 8]
Program output:

inv_freq: torch.Size([32])
Position embeddings:
  - cos shape: torch.Size([2, 8, 64])
  - sin shape: torch.Size([2, 8, 64])
query_states: torch.Size([2, 8, 8, 64])
causal_mask: torch.Size([2, 1, 8, 8])
attn_weights: torch.Size([2, 8, 8, 8])

Attention results:
  - Input shape: torch.Size([2, 8, 512])
  - Output shape: torch.Size([2, 8, 512])
  - Attention weights shape: torch.Size([2, 8, 8, 8])
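Beyond these shape checks, causality can also be verified: with the additive lower-triangular mask, every attention weight above the diagonal should be numerically zero. A small check that could be appended inside the __main__ block above (my addition):

# Causality check: no attention weight above the diagonal.
upper = torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool), diagonal=1)
assert attention_weights[..., upper].abs().max() < 1e-6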
