当前位置: 首页 > news >正文

深度解读DeepSeek:源码解读 DeepSeek-V3

深度解读DeepSeek:开源周(Open Source Week)技术解读
深度解读DeepSeek:源码解读 DeepSeek-V3
深度解读DeepSeek:技术原理
深度解读DeepSeek:发展历程

文章目录

  • 整体流程
  • 模型初始化
  • 模型前向传播
  • MLA
  • MoE

https://github.com/deepseek-ai/DeepSeek-V3

DeepSeek-V3属于Transformer Decoder模型,具体特点:

  • 生成式任务:代码通过generate函数实现文本生成(输入提示,输出补全),这是Decoder的核心功能。
  • 自回归生成:逐token预测,Decoder通过掩码机制(仅关注左侧上下文)实现这一点。

整体流程

main入口:

def main(
    ckpt_path: str,
    config: str,
    input_file: str = "",
    interactive: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
) -> None:
    """
    主函数:加载模型并执行交互式或批量文本生成
    
    Args:
        ckpt_path (str): 模型检查点目录路径
        config (str): 模型配置文件路径
        input_file (str, optional): 包含输入提示的文件路径。默认为空
        interactive (bool, optional): 是否启用交互模式。默认为True
        max_new_tokens (int, optional): 生成的最大新token数。默认为100
        temperature (float, optional): 采样温度参数。默认为1.0
    """
    # ==================== 分布式初始化 ====================
    # 获取分布式训练参数(多GPU/多节点)
    world_size = int(os.getenv("WORLD_SIZE", "1"))  # 总进程数
    rank = int(os.getenv("RANK", "0"))              # 当前进程的全局排名
    local_rank = int(os.getenv("LOCAL_RANK", "0"))  # 当前节点的本地排名

    # 初始化分布式进程组(NCCL后端)
    if world_size > 1:
        dist.init_process_group("nccl")

    # 仅在rank 0进程允许打印输出(避免多GPU重复打印)
    global print
    if rank != 0:
        print = lambda *_, **__: None  # 禁用非主进程的打印

    # ==================== 硬件设置 ====================
    torch.cuda.set_device(local_rank)       # 设置当前使用的GPU设备
    torch.set_default_dtype(torch.bfloat16) # 设置默认张量精度为bfloat16
    torch.set_num_threads(8)                # 限制CPU线程数以优化性能
    torch.manual_seed(965)                  # 设置随机种子保证可重复性

    # ==================== 模型加载 ====================
    # 从配置文件加载模型参数
    with open(config) as f:
        args = ModelArgs(**json.load(f))  # 解析JSON配置文件为模型参数对象
    print(args)  # 打印模型配置参数

    # 在CUDA设备上初始化模型
    with torch.device("cuda"):
        model = Transformer(args)  # 实例化Transformer模型
        
        # 加载分词器(自动检测格式)
        tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
        
        # 测试生成功能(生成2个token的示例)
        tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
        
        # 加载分片模型权重(分布式训练时每个进程加载对应分片)
        load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))

    # ==================== 交互模式处理 ====================
    if interactive:
        messages = []  # 保存对话历史的列表
        while True:
            # 分布式环境下的输入处理
            if world_size == 1:  # 单进程模式直接获取输入
                prompt = input(">>> ")
            elif rank == 0:      # 分布式模式下只有rank 0进程接收输入
                prompt = input(">>> ")
                objects = [prompt]
                dist.broadcast_object_list(objects, 0)  # 广播输入到其他进程
            else:
                objects = [None]
                dist.broadcast_object_list(objects, 0)  # 接收来自rank 0的输入
                prompt = objects[0]

            # 处理特殊命令
            if prompt == "/exit":    # 退出指令
                break
            elif prompt == "/clear": # 清空对话历史
                messages.clear()
                continue

            # 添加用户消息到对话历史
            messages.append({"role": "user", "content": prompt})
            
            # 格式化对话历史为模型输入格式(添加特殊标记)
            prompt_tokens = tokenizer.apply_chat_template(
                messages, 
                add_generation_prompt=True  # 添加生成提示(如<|assistant|>)
            )
            
            # 生成回复token(自回归生成)
            completion_tokens = generate(
                model, 
                [prompt_tokens], 
                max_new_tokens,          # 最大生成长度
                tokenizer.eos_token_id, # 终止符ID
                temperature             # 温度参数控制随机性
            )
            
            # 解码生成的token为文本
            completion = tokenizer.decode(
                completion_tokens[0], 
                skip_special_tokens=True  # 跳过特殊标记(如<|endoftext|>)
            )
            
            print(completion)  # 输出生成的回复
            
            # 添加助手回复到对话历史
            messages.append({"role": "assistant", "content": completion})

    # ==================== 批量模式处理 ====================
    else:
        # 从文件读取所有提示
        with open(input_file) as f:
            prompts = [line.strip() for line in f.readlines()]
        
        # 验证批量大小不超过模型限制
        assert len(prompts) <= args.max_batch_size, \
            f"提示数量超过最大批量大小 ({args.max_batch_size})"
        
        # 为每个提示格式化输入(单轮对话)
        prompt_tokens = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}], 
                add_generation_prompt=True
            ) for prompt in prompts
        ]
        
        # 批量生成回复
        completion_tokens = generate(
            model, 
            prompt_tokens, 
            max_new_tokens,
            tokenizer.eos_token_id,
            temperature
        )
        
        # 批量解码生成的token
        completions = tokenizer.batch_decode(
            completion_tokens, 
            skip_special_tokens=True
        )
        
        # 打印所有提示和对应的生成结果
        for prompt, completion in zip(prompts, completio  # 注意:此处代码不完整,应为completions

其中,生成回复token(自回归生成)的generate函数核心逻辑如下:

1,通过模型前向传播model.forward,获取当前位置的预测logits
2,通过贪心搜索获取概率最大的token = logits.argmax(dim=-1)
3,继续预测下个token,循环直到结束

def generate(
    model: Transformer,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0
) -> List[List[int]]:
	......
    # 自回归生成循环(逐个位置生成)
    for cur_pos in range(min(prompt_lens), total_len):
        # 获取当前位置的预测logits(仅处理新生成的部分)
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        
        # 根据温度选择采样策略
        if temperature > 0:
            next_token = sample(logits, temperature)  # 带温度采样
        else:
            next_token = logits.argmax(dim=-1)  # 贪心搜索
        
        # 保留原始prompt内容(当cur_pos在prompt范围内时使用原始token)
        next_token = torch.where(
            prompt_mask[:, cur_pos],
            tokens[:, cur_pos],  # 原始prompt部分保持原样
            next_token          # 生成部分使用新token
        )
        
        # 更新当前position的token
        tokens[:, cur_pos] = next_token
        
        # 更新完成标志(非prompt位置且生成终止符时标记完成)
        finished |= torch.logical_and(
            ~prompt_mask[:, cur_pos],  # 仅关注生成部分
            next_token == eos_id        # 检测终止符
        )
        
        prev_pos = cur_pos  # 更新处理位置
        
        # 提前终止条件:所有序列都已完成
        if finished.all():
            break

模型初始化

# 定义Transformer模型类,继承自PyTorch的nn.Module
class Transformer(nn.Module):
    """
    Transformer模型架构,包含:
    - 带位置编码的词嵌入
    - 多层Transformer块堆叠
    - 输出投影层
    
    Attributes:
        embed    : 将输入token映射为向量的嵌入层
        layers   : 包含多个Transformer块的模块列表
        norm     : 最终输出的归一化层
        head     : 将特征向量映射到词汇表空间的输出层
        freqs_cis: 预计算的旋转位置编码复数频率张量
    """

    def __init__(self, args):
        super().__init__()  # 调用父类nn.Module的初始化方法
        
        # Token嵌入层(模型并行实现)
        # args.vocab_size: 词汇表大小(102400)
        # args.dim      : 嵌入维度(2048)
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        
        # 创建Transformer块堆叠结构
        # args.n_layers: Transformer总层数(27)
        self.layers = torch.nn.ModuleList()  # 用于存储多个可训练模块的容器
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))  # 添加自定义的Transformer块(Block)

        # 最终输出归一化层(RMSNorm替代传统LayerNorm,减少约30%的计算操作)
        # RMSNorm(均方根归一化)是LayerNorm的变体,计算更高效,仅对输入向量的均方根值归一化,不涉及均值
        # args.dim: 输入特征维度(与嵌入维度一致2048)
        self.norm = RMSNorm(args.dim)
        
        # 输出投影层(模型并行线性层)
        # 将特征向量映射回词汇表空间生成logits
        # args.dim → args.vocab_size(102400)
        self.head = ColumnParallelLinear(
            args.dim, 
            args.vocab_size,
            dtype=torch.get_default_dtype()  # 保持与模型其他部分相同的数据类型
        )
        
        # 注册并预计算旋转位置编码的复数频率
        # precompute_freqs_cis: 根据dim和max_seq_len生成复数频率矩阵
        # persistent=False: 不保存到模型文件,加载时重新计算
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

每一层的Block:

# 定义Transformer的基础模块Block,继承自PyTorch的nn.Module
class Block(nn.Module):
    """
    Transformer块结构,整合注意力机制和前馈网络
    
    Attributes:
        attn     : Attention layer,自注意力计算层(MLA)
        ffn      : Feed-forward network,前馈网络(常规MLP或混合专家MoE结构)
        attn_norm: Layer normalization for attention.,注意力层后的归一化模块
        ffn_norm : Layer normalization for feed-forward network,前馈层后的归一化模块
    """

    def __init__(self, layer_id, args):
        super().__init__() 
        
        # 初始化注意力层(MLA)
        self.attn = MLA(args) 
        
        # 动态选择前馈网络类型(基于当前层ID)
        # 前1层使用普通MLP(args.inter_dim: 10944,MLP中间层维度),后续层使用MoE结构
        # MoE: 混合专家系统(如每层有多个专家网络+门控路由)
        self.ffn = (
            MLP(args.dim, args.inter_dim)  
            if layer_id < args.n_dense_layers  
            else MoE(args)  
        )
        
        # 初始化注意力后的归一化层
        self.attn_norm = RMSNorm(args.dim)  # 使用RMSNorm替代LayerNorm
        
        # 初始化前馈后的归一化层
        self.ffn_norm = RMSNorm(args.dim)  # 同上

总结主要方法如下:

  • embed : 将输入token映射为向量的嵌入层

    self.embed = ParallelEmbedding(args.vocab_size, args.dim)
    

    分布式词嵌入层,将词表拆分到多个GPU,降低单个GPU的显存占用,支持超大词表。

  • freqs_cis: 预计算的旋转位置编码复数频率张量

    self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
    

    旋转位置编码(Rotary Position Embedding, RoPE),通过复数旋转注入位置信息,提升注意力机制的位置感知能力

  • layers : 包含多个Transformer块的模块列表,每个Block代表1个Transformer层(包含自注意力 + FFN)

    self.attn = MLA(args) 
    self.ffn = (
        MLP(args.dim, args.inter_dim)  
        if layer_id < args.n_dense_layers  
        else MoE(args)  
    )
    

    注意力层:采用多头潜在注意力(Multi-Head Latent Attention,MLA)
    前馈网络层:前1层使用普通MLP,后续层使用混合专家模型(Mixture of Experts,MoE)

  • norm : 归一化层

    self.attn_norm = RMSNorm(args.dim)    # 注意力归一化层
    self.ffn_norm = RMSNorm(args.dim)   # 前馈网络归一化层
    self.norm = RMSNorm(args.dim)    # 最终的归一化层
    

    都采用均方根归一化(Root Mean Square Normalization),相比LayerNorm减少计算量,提升训练速度。

  • head : 将特征向量映射到词汇表空间的输出层

    self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
    

    输出头(列并行线性层,适用于分布式矩阵乘法):列并行输出头,分布式计算更高效

模型前向传播

前向传播forward过程:
输入tokens → embed → 添加位置编码 → 逐层Block处理 → norm → head → 输出logits

class Transformer(nn.Module):
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        """
        Forward pass for the Transformer model.

        Args:
            tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
            start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.

        Returns:
            torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
        """
        seqlen = tokens.size(1)
        # 1,输入处理:Token序列通过embed层转换为向量;结合预计算的freqs_cis应用旋转位置编码(在Block内部实现)
        h = self.embed(tokens)
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        # 2,特征提取:向量经过多个layers层Block处理,每层包含自注意力、前馈网络等操作    
        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        # 3,输出生成:最终输出经RMSNorm归一化后,通过head投影层生成logits  
        h = self.norm(h)[:, -1]
        logits = self.head(h)
        if world_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits

其中,逐层layer进行Block处理:

class Block(nn.Module):
    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
	    """
	    Transformer块的前向传播过程
	
	    Args:
	        x         : 输入张量,形状为 [批大小, 序列长度, 特征维度]
	        start_pos : 当前处理的起始位置(用于流式生成时定位缓存位置)
	        freqs_cis : 预计算旋转位置编码的复数频率张量
	        mask      : 注意力掩码,形状为 [序列长度, 序列长度] 或 [批大小, 序列长度]
	
	    Returns:
	        torch.Tensor: 输出张量,形状与输入x相同
	    """
        # 前向传播方法: 输入x → 带归一化(RMSNorm)的注意力(MLA)残差连接 → 带归一化(RMSNorm)的前馈(MoE)残差连接 → 输出
        # 此处采用Pre-Norm(先归一化再进入子层),相比Post-Norm更易于训练深层模型
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x

此步骤对应下面DeepSeek V3技术架构图,核心即MLA和MoE。
在这里插入图片描述

参数展开后完整拓扑如下:

Transformer(  # Transformer模型主容器
  (embed): ParallelEmbedding()  # 并行词嵌入层,分布式处理输入token到向量的映射
  (layers): ModuleList(  # 堆叠的Transformer层
    (0): Block(  # 第1个Transformer块(标准结构)
      (attn_norm): RMSNorm()  # 注意力子层的预归一化
      (attn): MLA(  # 改进的多头线性注意力层(可能含并行化)
        (wq): ColumnParallelLinear()  # 列并行线性层,分割查询矩阵Q的权重
        (wkv_a): Linear()  # 第一阶段线性变换,生成注意力中间表示
        (kv_norm): RMSNorm()  # 对Key/Value进行均方根归一化(无均值中心化)
        (wkv_b): ColumnParallelLinear()  # 列并行线性层,生成最终Key/Value
        (wo): RowParallelLinear()  # 行并行线性层,合并注意力头输出
      )
      (ffn_norm): RMSNorm()  # FFN子层的预归一化
      (ffn): MLP(  # 标准门控前馈网络
        (w1): ColumnParallelLinear()  # 门控列并行层(如GLU的g(x))
        (w2): RowParallelLinear()  # 行并行层,主变换(如GLU的f(x)) 
        (w3): ColumnParallelLinear()  # 另一门控列并行层(可能用于增强非线性)
      )
    )
    (1): Block(  # 第2个Transformer块(含MoE)
      (attn_norm): RMSNorm()  # 同上
      (attn): MLA(...)  # 同上,注意力结构相同
      (ffn_norm): RMSNorm()  # 同上
      (ffn): MoE(  # 混合专家层(稀疏激活)
        (gate): Gate()  # 路由门控,计算token到专家的分配权重
        (experts): ModuleList(  # 专家集合(独立处理不同模式)
          (0-63): 64 x Expert(  # 64个独立专家(如:每专家2-4层MLP)
            (w1): Linear()  # 专家特定门控层(无并行)
            (w2): Linear()  # 专家特定主变换
            (w3): Linear()  # 专家特定补充门控
          )
        )
        (shared_experts): MLP(  # 共享专家(所有token必经过)
          (w1): ColumnParallelLinear()  # 并行门控层
          (w2): RowParallelLinear()  # 并行主变换
          (w3): ColumnParallelLinear()  # 并行补充门控
        )
      )
    )
  )
  (norm): RMSNorm()  # 最终输出归一化层
  (head): ColumnParallelLinear()  # 并行输出头(词表投影,分布式计算logits)
)

MLA

# mla
q_lora_rank: int = 0
kv_lora_rank: int = 512
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128

class MLA(nn.Module):
  def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
     """
     多头注意力层(Multi-Headed Attention Layer, MLA)的前向传播过程。

     参数:
         x (torch.Tensor): 输入张量,形状为 (批次大小, 序列长度, 特征维度)。
         start_pos (int): 序列中的起始位置,用于缓存机制(如KV缓存)。
         freqs_cis (torch.Tensor): 预计算的旋转位置编码复数值(RoPE)。
         mask (Optional[torch.Tensor]): 注意力掩码张量,用于屏蔽无效位置。

     返回:
         torch.Tensor: 输出张量,形状与输入相同。
     """
    bsz, seqlen, _ = x.size()
    end_pos = start_pos + seqlen

    # ================== 查询(Query)路径处理 ==================
    # 根据LoRA配置选择不同的查询生成方式(标准线性层或LoRA低秩分解)
	if self.q_lora_rank == 0:
	    q = self.wq(x)  # 全秩投影
	else:
	    q = self.wq_b(self.q_norm(self.wq_a(x)))  # LoRA低秩投影:先降维再升维
    
    # 将查询张量重塑为多头结构 [批次, 序列长度, 头数, 头维度]
    q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
    # 分割查询为不使用位置编码(q_nope)和使用旋转位置编码(q_pe)的部分
    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
    # 对需要位置编码的部分应用旋转位置编码(RoPE)
    q_pe = apply_rotary_emb(q_pe, freqs_cis)

    # ================== 键值(Key-Value)路径处理 ==================
    # 生成初始键值张量(LoRA降维处理)
    kv = self.wkv_a(x)
    # 分割键值张量为低秩部分和需要位置编码的键部分
    kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    # 对键的位置编码部分应用旋转位置编码(需增加头维度)
    k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

    # ================== 注意力分数计算分支 ==================
    # Naive实现:传统注意力计算
    if attn_impl == "naive":
        q = torch.cat([q_nope, q_pe], dim=-1)
        kv = self.wkv_b(self.kv_norm(kv))
        kv = kv.view(...)
        k_nope, v = torch.split(kv, ...)
        k = torch.cat([k_nope, k_pe.expand(...)], dim=-1)
        # 更新缓存
        self.k_cache[:bsz, start_pos:end_pos] = k
        self.v_cache[:bsz, start_pos:end_pos] = v
        # 计算注意力分数
        scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
    # 优化实现:分离计算路径
    else:
        wkv_b = ...  # 可能的反量化操作
        q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
        # 更新组合缓存
        self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
        self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
        # 分数计算(分两部分)
        scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                  torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

    # Masking与Softmax
    if mask is not None:
        scores += mask.unsqueeze(1)
    scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)

    # 聚合值向量
    if attn_impl == "naive":
        # 传统路径:注意力权重与历史值缓存相乘
        x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
    else:
        # 优化路径:分步计算降低显存占用
        x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
        x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])

    # ================== 输出投影 ==================
    x = self.wo(x.flatten(2))  # 合并多头输出并通过最终线性层
    return x

在这里插入图片描述

这里重点解释注意力分数计算(分两种实现):

(a) Naive实现

 # 传统实现路径:合并查询的nope和rope部分
 q = torch.cat([q_nope, q_pe], dim=-1)
 # 对键值进行LoRA升维处理并重塑形状
 kv = self.wkv_b(self.kv_norm(kv))
 kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
 # 分割键值张量为键的nope部分和值部分
 k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 # 合并键的nope部分和位置编码部分(扩展维度对齐头数)
 k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
 # 更新KV缓存(用于自回归生成)
 self.k_cache[:bsz, start_pos:end_pos] = k
 self.v_cache[:bsz, start_pos:end_pos] = v
 # 计算注意力分数:Q与历史K的点积(爱因斯坦求和约定)
 scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale

在这里插入图片描述

b:	批次大小(batch)	8
s:	当前序列长度(query)	64
t:	总序列长度(key含历史)	1024
h:	注意力头数	16
d:	头维度(qk_head_dim)	128

然后进行掩码和softmax归一处理后,将注意力权重与历史值缓存相乘,合并多头输出并通过最终线性层。
实现如下完整数学公式:

在这里插入图片描述

(b) 优化实现

 # 优化实现路径:处理权重量化(FP8/INT8)
 wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
 wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
 # 对nope部分的查询进行低秩投影
 q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
 # 更新键值缓存和位置编码缓存
 self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
 self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
 # 分别计算nope路径和rope路径的注意力分数
 scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
           torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

在这里插入图片描述
然后进行掩码和softmax归一处理后,分步输出计算:

在这里插入图片描述

最后合并多头输出并通过最终线性层。

需要说明的是:在优化路径的注意力计算中,两次使用了 self.kv_cache, 是因为该缓存同时承担了 键(Key)的低秩投影 和 值(Value)的低秩投影 的双重角色。通过键值联合投影,将矩阵输入同时映射到键和值的联合低维空间,这一设计允许 kv_cache 同时编码键和值的信息。
第一次使用 kv_cache:作为键的低秩投影,用于键的升维,计算注意力分数。
第二次使用 kv_cache:作为值的低秩投影,用于值的升维,生成上下文向量。

两种方案对比:

  • 传统方法:存储完整的键 K 和值 V,形状为 [batch, seq, heads, dim]。显存占用较高,但无需额外拆分。
  • MLA优化方法:通过低秩缓存 + 位置编码分离,用额外计算换显存效率
    • kv_cache:存储低秩键值投影(kv_lora_rank 维度),减少显存占用, 形状为 [batch, seq, c]
    • pe_cache:单独缓存位置编码部分(k_pe),避免重复计算 RoPE, 形状为 [batch, seq, c]

批次大小b,序列长度s,头维度d,低秩维度c,那么:

MLA的总显存/传统显存  = (b*s*2*c) / (b*s*h*d) = 2c / hd

假设 h=16,d=128,c=64,节省约 93%

MoE

class MoE(nn.Module):
    """
    混合专家(Mixture-of-Experts)模块,支持分布式训练。

    Attributes:
        dim (int): 输入特征的维度
        n_routed_experts (int): 总专家数量
        n_local_experts (int): 当前GPU本地管理的专家数量(分布式场景)
        n_activated_experts (int): 每个输入激活的专家数量(top-k)
        gate (nn.Module): 门控网络,决定输入分配给哪些专家
        experts (nn.ModuleList): 路由专家列表(部分为None,分布式时只实例化本地专家)
        shared_experts (nn.Module): 共享专家网络,处理所有输入
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        # 分布式校验:确保总专家数能被GPU总数整除
        assert args.n_routed_experts % world_size == 0, f"专家数必须能被GPU数整除 (当前world_size={world_size})"
        
        # 专家数量配置
        self.n_routed_experts = args.n_routed_experts          # 总专家数
        self.n_local_experts = args.n_routed_experts // world_size  # 当前GPU管理的专家数
        self.n_activated_experts = args.n_activated_experts   # 每个输入激活的专家数
        
        # 分布式专家索引范围(将专家分配到不同GPU)
        self.experts_start_idx = rank * self.n_local_experts   # 当前GPU的起始专家编号
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts  # 结束专家编号
        
        # 门控网络(通常为线性层+softmax)
        self.gate = Gate(args)
        
        # 创建专家列表(仅在本地GPU实例化对应专家,其他位置为None节省内存)
        self.experts = nn.ModuleList([
            Expert(args.dim, args.moe_inter_dim) 
            if self.experts_start_idx <= i < self.experts_end_idx else None
            for i in range(self.n_routed_experts)
        ])
        
        # 共享专家(所有输入都会经过该模块,与路由专家互补)
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 展平输入:[batch, seq_len, dim] => [batch*seq_len, dim]
        shape = x.size()
        x = x.view(-1, self.dim)
        
        # 门控计算:获取top-k专家权重和索引
        weights, indices = self.gate(x)  # weights形状 [N, k], indices形状 [N, k]
        
        # 初始化输出张量
        y = torch.zeros_like(x)
        
        # 统计每个专家被选中的次数(用于跳过未激活的专家)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        
        # 仅处理当前GPU管理的专家范围
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:  # 该专家未被任何输入选中
                continue
            
            expert = self.experts[i]  # 获取本地专家实例
            idx, top = torch.where(indices == i)  # 找出选择该专家的输入位置
            
            # 加权求和:y[选中的输入] += 专家输出 * 对应权重
            y[idx] += expert(x[idx]) * weights[idx, top, None]  # None用于广播维度
        
        # 共享专家处理所有输入
        z = self.shared_experts(x)
        
        # 分布式通信:聚合所有GPU上的专家输出
        if world_size > 1:
            dist.all_reduce(y)  # 使用AllReduce同步各GPU计算结果
        
        # 合并路由专家和共享专家的结果,并恢复原始形状
        return (y + z).view(shape)

关键机制解析:

1, 分布式专家分配:

  • n_routed_experts: 总专家数(需为GPU数的整数倍)
  • n_local_experts: 当前GPU管理的专家数(=总专家数 / GPU数)
  • experts列表: 只实例化当前GPU负责的专家,其他位置为None,极大减少显存占用

2. 双路径计算:

  • 路由专家路径
    通过门控动态选择top-k专家,返回每行top-k值得权重weights和专家索引indices。
    门控计算示例:

    # 假设某输入的门控输出为[0.1, 0.5, 0.3, 0.1](softmax后每个专家的概率),如果topk=2:
    weights = [0.5, 0.3]  
    indices = [1, 2]       # 对应的专家编号
    

    典型计算流程:

      A[输入x] --> B(门控网络)
      B --> C{选择top-k专家}
      C --> D[专家1计算]
      C --> E[专家2计算]
      D --> F(加权求和)
      E --> F
    
  • 共享专家路径
    所有输入均经过shared_experts,作用包括:

    • 提供基础特征变换能力
    • 缓解专家路由不均衡问题
    • 增强模型鲁棒性(即使某些路由专家失效仍有保底处理)

3. 通信优化:

all_reduce同步:确保各GPU上的部分结果y能全局聚合

设计特点总结:

  • 显存高效
    通过分布式专家分配,单个GPU只需保存部分专家参数。假设总专家数=64,GPU数=8,则每个GPU仅维护8个专家,显存占用减少为1/8。

  • 计算负载均衡
    torch.bincount统计专家使用次数,跳过未激活专家,避免无效计算。

  • 灵活扩展
    调整n_routed_experts和n_activated_experts可平衡模型容量与计算量
    共享专家提供保底处理能力,提升模型稳定性

相关文章:

  • 动态规划-基础
  • ESP8266 RTOS SDK 使用make命令编译出现Permission denied问题的解决方法
  • Ubuntu 14.10 Desktop (i386):经典 32 位操作系统的回顾与指南(附安装包)
  • 基于yolov11的防震锤缺陷检测系统python源码+pytorch模型+评估指标曲线+精美GUI界面
  • WSL 环境桥接与雷达通信配置笔记
  • APM 仿真遥控指南
  • 音频录制小妙招-自制工具-借助浏览器录一段单声道16000采样率wav格式音频
  • ARM架构薄记2——ARM学习架构抓手(以ARMv7为例子)
  • 元音辅音及其字母组合发音
  • 基于STM32进行FFT滤波
  • Python 常用内建模块-urllib
  • LINUX基础 [二] - 进程概念
  • 简单实用!百度AI + Raphael AI = 免费生图
  • CSS 中flex - grow、flex - shrink和flex - basis属性的含义及它们在弹性盒布局中的协同作用。
  • 以“无敏”理念守护婴童健康成长,Witsbb健敏思获京东健康“新锐突破奖”
  • [笔记.AI]多头自注意力机制(Multi-Head Attention)
  • C# 元组
  • 【图像生成之十八】Seedream 2.0
  • 计算机网络总结
  • OpenHarmony 开源硬件学习全指南:从入门到实战
  • 康子兴评《文明的追求》|野人脚印:鲁滨逊的恐惧与文明焦虑
  • 比尔·盖茨:未来20年通过盖茨基金会捐出几乎全部财富,2045年底基金会停止运营
  • 习近平同俄罗斯总统普京会谈
  • 奥利弗·斯通回顾越战50周年:我们不善于总结历史教训
  • 司法部:持续规范行政执法行为,加快制定行政执法监督条例
  • 外交部:中方和欧洲议会决定同步全面取消对相互交往的限制