Import the required libraries
import torch
import torch.nn as nn
import math
from typing import Callable, Optional, Tuple
from dataclasses import dataclass
import typing
from transformers.utils import TransformersKwargs
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS  # used when a non-eager attention backend is selected

Unpack = typing.Unpack
Rotary position embedding (RoPE) helper functions
def rotate_half(x):
    """Rotate the second half of the last dimension into the first half's place."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to the query and key tensors."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
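A quick sanity check helps here: RoPE is a pure per-pair rotation, so applying it leaves shapes unchanged and preserves the norm of every head vector. A minimal sketch (the cos/sin below are fabricated with the same duplicated-half layout that Qwen3MoeRotaryEmbedding produces later; the shapes mirror the test at the end of this post):

# Hypothetical shapes: batch=2, heads=8, seq=8, head_dim=64
q = torch.randn(2, 8, 8, 64)
k = torch.randn(2, 8, 8, 64)
freqs = torch.randn(2, 8, 32)              # one angle per rotated pair
emb = torch.cat((freqs, freqs), dim=-1)    # duplicate halves, as the rotary module does
cos, sin = emb.cos(), emb.sin()

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape)                                                    # torch.Size([2, 8, 8, 64])
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True: rotations preserve norms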
Key/value repetition function
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
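With the MockConfig used in the test below (8 query heads, 4 key/value heads), each KV head is shared by n_rep = 2 query heads. A minimal sketch of what repeat_kv does to the cached key/value layout:

kv = torch.randn(2, 4, 8, 64)   # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=2)
print(expanded.shape)           # torch.Size([2, 8, 8, 64])
# Expanded heads 0 and 1 are both copies of original KV head 0
print(torch.equal(expanded[:, 0], kv[:, 0]), torch.equal(expanded[:, 1], kv[:, 0]))  # True True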
Eager attention forward function
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        print("causal_mask:", causal_mask.shape)
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    print("attn_weights:", attn_weights.shape)

    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
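For comparison, the same computation can be handed to PyTorch's fused scaled_dot_product_attention. The sketch below is only illustrative (the helper name sdpa_attention_forward is mine, and it assumes PyTorch ≥ 2.1 for the scale argument); it mirrors what the non-eager branch in Qwen3MoeAttention.forward would dispatch to, except that fused kernels do not return the attention weights:

import torch.nn.functional as F

def sdpa_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Same GQA expansion as the eager path
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    causal_mask = None
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
    attn_output = F.scaled_dot_product_attention(
        query, key_states, value_states,
        attn_mask=causal_mask,                          # additive float mask, as in the eager path
        dropout_p=dropout if module.training else 0.0,
        scale=scaling,
    )
    # Fused kernels do not expose the softmax weights
    return attn_output.transpose(1, 2).contiguous(), None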
RoPE position embedding implementation
def default_rope_init(config, device=None):
    """Default RoPE initialization function."""
    dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
    inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    print("inv_freq:", inv_freq.shape)
    return inv_freq.to(device), 1.0
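With head_dim = 64 and rope_theta = 10000.0 (the values used later in MockConfig), this produces 32 inverse frequencies following inv_freq[i] = theta^(-2i/dim), spanning from 1.0 down to roughly 1.3e-4. A quick check with a throwaway config object standing in for the real one:

from types import SimpleNamespace

cfg = SimpleNamespace(head_dim=64, rope_theta=10000.0)    # stand-in config, for illustration only
inv_freq, attention_scaling = default_rope_init(cfg)
print(inv_freq[0].item(), inv_freq[-1].item())            # 1.0 and ~1.33e-4 (= 10000 ** (-62/64))
print(attention_scaling)                                  # 1.0 (no scaling for the default rope type)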
ROPE_INIT_FUNCTIONS = {
    "default": default_rope_init,
}


class Qwen3MoeRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor

    def __init__(self, config, device=None):
        super().__init__()
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
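One detail worth verifying: because emb = torch.cat((freqs, freqs), dim=-1), the second half of the last dimension duplicates the first half, which is exactly the pairing rotate_half relies on, and each entry is cos(position * inv_freq[i]). A minimal check (again with a throwaway config standing in for the real one):

from types import SimpleNamespace

cfg = SimpleNamespace(head_dim=64, rope_theta=10000.0, max_position_embeddings=2048)
rope = Qwen3MoeRotaryEmbedding(cfg)
x = torch.randn(1, 8, 512)                        # only the dtype/device of x are used
cos, sin = rope(x, torch.arange(8).unsqueeze(0))
print(cos.shape)                                  # torch.Size([1, 8, 64])
print(torch.equal(cos[..., :32], cos[..., 32:]))  # True: the two halves share the same angles
print(torch.allclose(cos[0, 3, :32], torch.cos(3 * rope.inv_freq)))  # True: angle = position * inv_freq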
Qwen3Moe attention mechanism implementation
class Qwen3MoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
        self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.sliding_window = getattr(config, "sliding_window", None)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        print("query_states:", query_states.shape)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
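Two things distinguish this block from a vanilla multi-head attention: q_norm/k_norm apply RMSNorm per head (over head_dim, not hidden_size) before RoPE, and the K/V projections are narrower because of grouped-query attention. A minimal shape check, using the MockConfig defined in the next section:

attn = Qwen3MoeAttention(MockConfig(), layer_idx=0)
print(attn.q_proj.weight.shape)    # torch.Size([512, 512])  -> 8 query heads * 64 dims
print(attn.k_proj.weight.shape)    # torch.Size([256, 512])  -> only 4 KV heads * 64 dims
print(attn.q_norm.weight.shape)    # torch.Size([64])        -> QK-norm acts on each head separately
print(attn.num_key_value_groups)   # 2: each KV head serves two query heads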
Mock config and RMSNorm implementation
@dataclass
class MockConfig:
    hidden_size: int = 512
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    head_dim: int = 64
    max_position_embeddings: int = 2048
    rope_theta: float = 10000.0
    rms_norm_eps: float = 1e-6
    attention_bias: bool = False
    attention_dropout: float = 0.0
    _attn_implementation: str = "eager"


class Qwen3MoeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
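Qwen3MoeRMSNorm normalizes by the root mean square of the last dimension in float32 (no mean subtraction, unlike LayerNorm) and rescales with a learned weight. If your PyTorch is recent enough (≥ 2.4, an assumption on my part, which ships a built-in nn.RMSNorm), the two implementations agree numerically:

x = torch.randn(2, 8, 512)
custom = Qwen3MoeRMSNorm(512, eps=1e-6)
builtin = nn.RMSNorm(512, eps=1e-6)    # requires PyTorch >= 2.4
print(torch.allclose(custom(x), builtin(x), atol=1e-5))   # True: both weights start as all-ones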
Main function: test code
if __name__ == "__main__":
    config = MockConfig()
    attention_layer = Qwen3MoeAttention(config, layer_idx=0)
    rotary_emb = Qwen3MoeRotaryEmbedding(config)

    batch_size = 2
    seq_length = 8
    hidden_size = config.hidden_size

    hidden_states = torch.randn(batch_size, seq_length, hidden_size)
    position_ids = torch.arange(seq_length).unsqueeze(0).expand(batch_size, -1)

    cos, sin = rotary_emb(hidden_states, position_ids)
    print(f"Position embeddings:")
    print(f" - cos shape: {cos.shape}")
    print(f" - sin shape: {sin.shape}")

    attention_mask = torch.tril(torch.ones(batch_size, 1, seq_length, seq_length))
    attention_mask = (1.0 - attention_mask) * torch.finfo(torch.float32).min

    attention_output, attention_weights = attention_layer(
        hidden_states=hidden_states,
        position_embeddings=(cos, sin),
        attention_mask=attention_mask,
    )
    print(f"\nAttention results:")
    print(f" - Input shape: {hidden_states.shape}")
    print(f" - Output shape: {attention_output.shape}")
    print(f" - Attention weights shape: {attention_weights.shape}")
inv_freq: torch.Size([32])
Position embeddings:
 - cos shape: torch.Size([2, 8, 64])
 - sin shape: torch.Size([2, 8, 64])
query_states: torch.Size([2, 8, 8, 64])
causal_mask: torch.Size([2, 1, 8, 8])
attn_weights: torch.Size([2, 8, 8, 8])

Attention results:
 - Input shape: torch.Size([2, 8, 512])
 - Output shape: torch.Size([2, 8, 512])
 - Attention weights shape: torch.Size([2, 8, 8, 8])
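The test above never exercises the past_key_values branch. A minimal sketch of incremental decoding with transformers' DynamicCache (an assumption that your installed version exposes it; the attention mask is omitted here for brevity, so the prefill step is not causally masked):

from transformers import DynamicCache

config = MockConfig()
attn = Qwen3MoeAttention(config, layer_idx=0)
rope = Qwen3MoeRotaryEmbedding(config)
cache = DynamicCache()

# Prefill 8 tokens, caching their keys/values
prompt = torch.randn(2, 8, config.hidden_size)
pos = torch.arange(8).unsqueeze(0).expand(2, -1)
attn(prompt, rope(prompt, pos), attention_mask=None,
     past_key_values=cache, cache_position=torch.arange(8))

# Decode one more token: its query attends over 8 cached + 1 new key/value pairs
new_tok = torch.randn(2, 1, config.hidden_size)
new_pos = torch.full((2, 1), 8)
out, _ = attn(new_tok, rope(new_tok, new_pos), attention_mask=None,
              past_key_values=cache, cache_position=torch.tensor([8]))
print(out.shape)                # torch.Size([2, 1, 512])
print(cache.get_seq_length())   # 9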