活动网页怎么做苏州搜索引擎排名优化商家
深度解读DeepSeek:开源周(Open Source Week)技术解读
深度解读DeepSeek:源码解读 DeepSeek-V3
深度解读DeepSeek:技术原理
深度解读DeepSeek:发展历程
文章目录
- 整体流程
- 模型初始化
- 模型前向传播
- MLA
- MoE
https://github.com/deepseek-ai/DeepSeek-V3
DeepSeek-V3属于Transformer Decoder模型,具体特点:
- 生成式任务:代码通过generate函数实现文本生成(输入提示,输出补全),这是Decoder的核心功能。
- 自回归生成:逐token预测,Decoder通过掩码机制(仅关注左侧上下文)实现这一点。
整体流程
main入口:
def main(ckpt_path: str,config: str,input_file: str = "",interactive: bool = True,max_new_tokens: int = 100,temperature: float = 1.0,
) -> None:"""主函数:加载模型并执行交互式或批量文本生成Args:ckpt_path (str): 模型检查点目录路径config (str): 模型配置文件路径input_file (str, optional): 包含输入提示的文件路径。默认为空interactive (bool, optional): 是否启用交互模式。默认为Truemax_new_tokens (int, optional): 生成的最大新token数。默认为100temperature (float, optional): 采样温度参数。默认为1.0"""# ==================== 分布式初始化 ====================# 获取分布式训练参数(多GPU/多节点)world_size = int(os.getenv("WORLD_SIZE", "1")) # 总进程数rank = int(os.getenv("RANK", "0")) # 当前进程的全局排名local_rank = int(os.getenv("LOCAL_RANK", "0")) # 当前节点的本地排名# 初始化分布式进程组(NCCL后端)if world_size > 1:dist.init_process_group("nccl")# 仅在rank 0进程允许打印输出(避免多GPU重复打印)global printif rank != 0:print = lambda *_, **__: None # 禁用非主进程的打印# ==================== 硬件设置 ====================torch.cuda.set_device(local_rank) # 设置当前使用的GPU设备torch.set_default_dtype(torch.bfloat16) # 设置默认张量精度为bfloat16torch.set_num_threads(8) # 限制CPU线程数以优化性能torch.manual_seed(965) # 设置随机种子保证可重复性# ==================== 模型加载 ====================# 从配置文件加载模型参数with open(config) as f:args = ModelArgs(**json.load(f)) # 解析JSON配置文件为模型参数对象print(args) # 打印模型配置参数# 在CUDA设备上初始化模型with torch.device("cuda"):model = Transformer(args) # 实例化Transformer模型# 加载分词器(自动检测格式)tokenizer = AutoTokenizer.from_pretrained(ckpt_path)# 测试生成功能(生成2个token的示例)tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])# 加载分片模型权重(分布式训练时每个进程加载对应分片)load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))# ==================== 交互模式处理 ====================if interactive:messages = [] # 保存对话历史的列表while True:# 分布式环境下的输入处理if world_size == 1: # 单进程模式直接获取输入prompt = input(">>> ")elif rank == 0: # 分布式模式下只有rank 0进程接收输入prompt = input(">>> ")objects = [prompt]dist.broadcast_object_list(objects, 0) # 广播输入到其他进程else:objects = [None]dist.broadcast_object_list(objects, 0) # 接收来自rank 0的输入prompt = objects[0]# 处理特殊命令if prompt == "/exit": # 退出指令breakelif prompt == "/clear": # 清空对话历史messages.clear()continue# 添加用户消息到对话历史messages.append({"role": "user", "content": prompt})# 格式化对话历史为模型输入格式(添加特殊标记)prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True # 添加生成提示(如<|assistant|>))# 生成回复token(自回归生成)completion_tokens = generate(model, [prompt_tokens], max_new_tokens, # 最大生成长度tokenizer.eos_token_id, # 终止符IDtemperature # 温度参数控制随机性)# 解码生成的token为文本completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True # 跳过特殊标记(如<|endoftext|>))print(completion) # 输出生成的回复# 添加助手回复到对话历史messages.append({"role": "assistant", "content": completion})# ==================== 批量模式处理 ====================else:# 从文件读取所有提示with open(input_file) as f:prompts = [line.strip() for line in f.readlines()]# 验证批量大小不超过模型限制assert len(prompts) <= args.max_batch_size, \f"提示数量超过最大批量大小 ({args.max_batch_size})"# 为每个提示格式化输入(单轮对话)prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]# 批量生成回复completion_tokens = generate(model, prompt_tokens, max_new_tokens,tokenizer.eos_token_id,temperature)# 批量解码生成的tokencompletions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)# 打印所有提示和对应的生成结果for prompt, completion in zip(prompts, completio # 注意:此处代码不完整,应为completions
其中,生成回复token(自回归生成)的generate
函数核心逻辑如下:
1,通过模型前向传播model.forward
,获取当前位置的预测logits
2,通过贪心搜索获取概率最大的token = logits.argmax(dim=-1)
3,继续预测下个token,循环直到结束
def generate(model: Transformer,prompt_tokens: List[List[int]],max_new_tokens: int,eos_id: int,temperature: float = 1.0
) -> List[List[int]]:......# 自回归生成循环(逐个位置生成)for cur_pos in range(min(prompt_lens), total_len):# 获取当前位置的预测logits(仅处理新生成的部分)logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)# 根据温度选择采样策略if temperature > 0:next_token = sample(logits, temperature) # 带温度采样else:next_token = logits.argmax(dim=-1) # 贪心搜索# 保留原始prompt内容(当cur_pos在prompt范围内时使用原始token)next_token = torch.where(prompt_mask[:, cur_pos],tokens[:, cur_pos], # 原始prompt部分保持原样next_token # 生成部分使用新token)# 更新当前position的tokentokens[:, cur_pos] = next_token# 更新完成标志(非prompt位置且生成终止符时标记完成)finished |= torch.logical_and(~prompt_mask[:, cur_pos], # 仅关注生成部分next_token == eos_id # 检测终止符)prev_pos = cur_pos # 更新处理位置# 提前终止条件:所有序列都已完成if finished.all():break
模型初始化
# 定义Transformer模型类,继承自PyTorch的nn.Module
class Transformer(nn.Module):"""Transformer模型架构,包含:- 带位置编码的词嵌入- 多层Transformer块堆叠- 输出投影层Attributes:embed : 将输入token映射为向量的嵌入层layers : 包含多个Transformer块的模块列表norm : 最终输出的归一化层head : 将特征向量映射到词汇表空间的输出层freqs_cis: 预计算的旋转位置编码复数频率张量"""def __init__(self, args):super().__init__() # 调用父类nn.Module的初始化方法# Token嵌入层(模型并行实现)# args.vocab_size: 词汇表大小(102400)# args.dim : 嵌入维度(2048)self.embed = ParallelEmbedding(args.vocab_size, args.dim)# 创建Transformer块堆叠结构# args.n_layers: Transformer总层数(27)self.layers = torch.nn.ModuleList() # 用于存储多个可训练模块的容器for layer_id in range(args.n_layers):self.layers.append(Block(layer_id, args)) # 添加自定义的Transformer块(Block)# 最终输出归一化层(RMSNorm替代传统LayerNorm,减少约30%的计算操作)# RMSNorm(均方根归一化)是LayerNorm的变体,计算更高效,仅对输入向量的均方根值归一化,不涉及均值# args.dim: 输入特征维度(与嵌入维度一致2048)self.norm = RMSNorm(args.dim)# 输出投影层(模型并行线性层)# 将特征向量映射回词汇表空间生成logits# args.dim → args.vocab_size(102400)self.head = ColumnParallelLinear(args.dim, args.vocab_size,dtype=torch.get_default_dtype() # 保持与模型其他部分相同的数据类型)# 注册并预计算旋转位置编码的复数频率# precompute_freqs_cis: 根据dim和max_seq_len生成复数频率矩阵# persistent=False: 不保存到模型文件,加载时重新计算self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
每一层的Block:
# 定义Transformer的基础模块Block,继承自PyTorch的nn.Module
class Block(nn.Module):"""Transformer块结构,整合注意力机制和前馈网络Attributes:attn : Attention layer,自注意力计算层(MLA)ffn : Feed-forward network,前馈网络(常规MLP或混合专家MoE结构)attn_norm: Layer normalization for attention.,注意力层后的归一化模块ffn_norm : Layer normalization for feed-forward network,前馈层后的归一化模块"""def __init__(self, layer_id, args):super().__init__() # 初始化注意力层(MLA)self.attn = MLA(args) # 动态选择前馈网络类型(基于当前层ID)# 前1层使用普通MLP(args.inter_dim: 10944,MLP中间层维度),后续层使用MoE结构# MoE: 混合专家系统(如每层有多个专家网络+门控路由)self.ffn = (MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args) )# 初始化注意力后的归一化层self.attn_norm = RMSNorm(args.dim) # 使用RMSNorm替代LayerNorm# 初始化前馈后的归一化层self.ffn_norm = RMSNorm(args.dim) # 同上
总结主要方法如下:
-
embed : 将输入token映射为向量的嵌入层
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
分布式词嵌入层,将词表拆分到多个GPU,降低单个GPU的显存占用,支持超大词表。
-
freqs_cis: 预计算的旋转位置编码复数频率张量
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
旋转位置编码(Rotary Position Embedding, RoPE),通过复数旋转注入位置信息,提升注意力机制的位置感知能力
-
layers : 包含多个Transformer块的模块列表,每个Block代表1个Transformer层(包含自注意力 + FFN)
self.attn = MLA(args) self.ffn = (MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args) )
注意力层:采用多头潜在注意力(Multi-Head Latent Attention,MLA)
前馈网络层:前1层使用普通MLP,后续层使用混合专家模型(Mixture of Experts,MoE) -
norm : 归一化层
self.attn_norm = RMSNorm(args.dim) # 注意力归一化层 self.ffn_norm = RMSNorm(args.dim) # 前馈网络归一化层 self.norm = RMSNorm(args.dim) # 最终的归一化层
都采用均方根归一化(Root Mean Square Normalization),相比LayerNorm减少计算量,提升训练速度。
-
head : 将特征向量映射到词汇表空间的输出层
self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
输出头(列并行线性层,适用于分布式矩阵乘法):列并行输出头,分布式计算更高效
模型前向传播
前向传播forward过程:
输入tokens → embed → 添加位置编码 → 逐层Block处理 → norm → head → 输出logits
class Transformer(nn.Module):def forward(self, tokens: torch.Tensor, start_pos: int = 0):"""Forward pass for the Transformer model.Args:tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.Returns:torch.Tensor: Logits tensor of shape (batch_size, vocab_size)."""seqlen = tokens.size(1)# 1,输入处理:Token序列通过embed层转换为向量;结合预计算的freqs_cis应用旋转位置编码(在Block内部实现)h = self.embed(tokens)freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]# 2,特征提取:向量经过多个layers层Block处理,每层包含自注意力、前馈网络等操作 mask = Noneif seqlen > 1:mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)for layer in self.layers:h = layer(h, start_pos, freqs_cis, mask)# 3,输出生成:最终输出经RMSNorm归一化后,通过head投影层生成logits h = self.norm(h)[:, -1]logits = self.head(h)if world_size > 1:all_logits = [torch.empty_like(logits) for _ in range(world_size)]dist.all_gather(all_logits, logits)logits = torch.cat(all_logits, dim=-1)return logits
其中,逐层layer进行Block处理:
class Block(nn.Module):def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:"""Transformer块的前向传播过程Args:x : 输入张量,形状为 [批大小, 序列长度, 特征维度]start_pos : 当前处理的起始位置(用于流式生成时定位缓存位置)freqs_cis : 预计算旋转位置编码的复数频率张量mask : 注意力掩码,形状为 [序列长度, 序列长度] 或 [批大小, 序列长度]Returns:torch.Tensor: 输出张量,形状与输入x相同"""# 前向传播方法: 输入x → 带归一化(RMSNorm)的注意力(MLA)残差连接 → 带归一化(RMSNorm)的前馈(MoE)残差连接 → 输出# 此处采用Pre-Norm(先归一化再进入子层),相比Post-Norm更易于训练深层模型x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)x = x + self.ffn(self.ffn_norm(x))return x
此步骤对应下面DeepSeek V3技术架构图,核心即MLA和MoE。
参数展开后完整拓扑如下:
Transformer( # Transformer模型主容器(embed): ParallelEmbedding() # 并行词嵌入层,分布式处理输入token到向量的映射(layers): ModuleList( # 堆叠的Transformer层(0): Block( # 第1个Transformer块(标准结构)(attn_norm): RMSNorm() # 注意力子层的预归一化(attn): MLA( # 改进的多头线性注意力层(可能含并行化)(wq): ColumnParallelLinear() # 列并行线性层,分割查询矩阵Q的权重(wkv_a): Linear() # 第一阶段线性变换,生成注意力中间表示(kv_norm): RMSNorm() # 对Key/Value进行均方根归一化(无均值中心化)(wkv_b): ColumnParallelLinear() # 列并行线性层,生成最终Key/Value(wo): RowParallelLinear() # 行并行线性层,合并注意力头输出)(ffn_norm): RMSNorm() # FFN子层的预归一化(ffn): MLP( # 标准门控前馈网络(w1): ColumnParallelLinear() # 门控列并行层(如GLU的g(x))(w2): RowParallelLinear() # 行并行层,主变换(如GLU的f(x)) (w3): ColumnParallelLinear() # 另一门控列并行层(可能用于增强非线性)))(1): Block( # 第2个Transformer块(含MoE)(attn_norm): RMSNorm() # 同上(attn): MLA(...) # 同上,注意力结构相同(ffn_norm): RMSNorm() # 同上(ffn): MoE( # 混合专家层(稀疏激活)(gate): Gate() # 路由门控,计算token到专家的分配权重(experts): ModuleList( # 专家集合(独立处理不同模式)(0-63): 64 x Expert( # 64个独立专家(如:每专家2-4层MLP)(w1): Linear() # 专家特定门控层(无并行)(w2): Linear() # 专家特定主变换(w3): Linear() # 专家特定补充门控))(shared_experts): MLP( # 共享专家(所有token必经过)(w1): ColumnParallelLinear() # 并行门控层(w2): RowParallelLinear() # 并行主变换(w3): ColumnParallelLinear() # 并行补充门控))))(norm): RMSNorm() # 最终输出归一化层(head): ColumnParallelLinear() # 并行输出头(词表投影,分布式计算logits)
)
MLA
# mla
q_lora_rank: int = 0
kv_lora_rank: int = 512
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128class MLA(nn.Module):def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):"""多头注意力层(Multi-Headed Attention Layer, MLA)的前向传播过程。参数:x (torch.Tensor): 输入张量,形状为 (批次大小, 序列长度, 特征维度)。start_pos (int): 序列中的起始位置,用于缓存机制(如KV缓存)。freqs_cis (torch.Tensor): 预计算的旋转位置编码复数值(RoPE)。mask (Optional[torch.Tensor]): 注意力掩码张量,用于屏蔽无效位置。返回:torch.Tensor: 输出张量,形状与输入相同。"""bsz, seqlen, _ = x.size()end_pos = start_pos + seqlen# ================== 查询(Query)路径处理 ==================# 根据LoRA配置选择不同的查询生成方式(标准线性层或LoRA低秩分解)if self.q_lora_rank == 0:q = self.wq(x) # 全秩投影else:q = self.wq_b(self.q_norm(self.wq_a(x))) # LoRA低秩投影:先降维再升维# 将查询张量重塑为多头结构 [批次, 序列长度, 头数, 头维度]q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)# 分割查询为不使用位置编码(q_nope)和使用旋转位置编码(q_pe)的部分q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)# 对需要位置编码的部分应用旋转位置编码(RoPE)q_pe = apply_rotary_emb(q_pe, freqs_cis)# ================== 键值(Key-Value)路径处理 ==================# 生成初始键值张量(LoRA降维处理)kv = self.wkv_a(x)# 分割键值张量为低秩部分和需要位置编码的键部分kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)# 对键的位置编码部分应用旋转位置编码(需增加头维度)k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)# ================== 注意力分数计算分支 ==================# Naive实现:传统注意力计算if attn_impl == "naive":q = torch.cat([q_nope, q_pe], dim=-1)kv = self.wkv_b(self.kv_norm(kv))kv = kv.view(...)k_nope, v = torch.split(kv, ...)k = torch.cat([k_nope, k_pe.expand(...)], dim=-1)# 更新缓存self.k_cache[:bsz, start_pos:end_pos] = kself.v_cache[:bsz, start_pos:end_pos] = v# 计算注意力分数scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale# 优化实现:分离计算路径else:wkv_b = ... # 可能的反量化操作q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])# 更新组合缓存self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)# 分数计算(分两部分)scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale# Masking与Softmaxif mask is not None:scores += mask.unsqueeze(1)scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)# 聚合值向量if attn_impl == "naive":# 传统路径:注意力权重与历史值缓存相乘x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])else:# 优化路径:分步计算降低显存占用x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])# ================== 输出投影 ==================x = self.wo(x.flatten(2)) # 合并多头输出并通过最终线性层return x
这里重点解释注意力分数计算(分两种实现):
(a) Naive实现
# 传统实现路径:合并查询的nope和rope部分q = torch.cat([q_nope, q_pe], dim=-1)# 对键值进行LoRA升维处理并重塑形状kv = self.wkv_b(self.kv_norm(kv))kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)# 分割键值张量为键的nope部分和值部分k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)# 合并键的nope部分和位置编码部分(扩展维度对齐头数)k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)# 更新KV缓存(用于自回归生成)self.k_cache[:bsz, start_pos:end_pos] = kself.v_cache[:bsz, start_pos:end_pos] = v# 计算注意力分数:Q与历史K的点积(爱因斯坦求和约定)scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
b: 批次大小(batch) 8
s: 当前序列长度(query) 64
t: 总序列长度(key含历史) 1024
h: 注意力头数 16
d: 头维度(qk_head_dim) 128
然后进行掩码和softmax归一处理后,将注意力权重与历史值缓存相乘,合并多头输出并通过最终线性层。
实现如下完整数学公式:
(b) 优化实现
# 优化实现路径:处理权重量化(FP8/INT8)wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)# 对nope部分的查询进行低秩投影q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])# 更新键值缓存和位置编码缓存self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)# 分别计算nope路径和rope路径的注意力分数scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
然后进行掩码和softmax归一处理后,分步输出计算:
最后合并多头输出并通过最终线性层。
需要说明的是:在优化路径的注意力计算中,两次使用了 self.kv_cache, 是因为该缓存同时承担了 键(Key)的低秩投影 和 值(Value)的低秩投影 的双重角色。通过键值联合投影,将矩阵输入同时映射到键和值的联合低维空间,这一设计允许 kv_cache 同时编码键和值的信息。
第一次使用 kv_cache:作为键的低秩投影,用于键的升维,计算注意力分数。
第二次使用 kv_cache:作为值的低秩投影,用于值的升维,生成上下文向量。
两种方案对比:
- 传统方法:存储完整的键 K 和值 V,形状为
[batch, seq, heads, dim]
。显存占用较高,但无需额外拆分。 - MLA优化方法:通过低秩缓存 + 位置编码分离,用额外计算换显存效率
- kv_cache:存储低秩键值投影(kv_lora_rank 维度),减少显存占用, 形状为
[batch, seq, c]
。 - pe_cache:单独缓存位置编码部分(k_pe),避免重复计算 RoPE, 形状为
[batch, seq, c]
。
- kv_cache:存储低秩键值投影(kv_lora_rank 维度),减少显存占用, 形状为
批次大小b,序列长度s,头维度d,低秩维度c,那么:
MLA的总显存/传统显存 = (b*s*2*c) / (b*s*h*d) = 2c / hd
假设 h=16,d=128,c=64,节省约 93%
MoE
class MoE(nn.Module):"""混合专家(Mixture-of-Experts)模块,支持分布式训练。Attributes:dim (int): 输入特征的维度n_routed_experts (int): 总专家数量n_local_experts (int): 当前GPU本地管理的专家数量(分布式场景)n_activated_experts (int): 每个输入激活的专家数量(top-k)gate (nn.Module): 门控网络,决定输入分配给哪些专家experts (nn.ModuleList): 路由专家列表(部分为None,分布式时只实例化本地专家)shared_experts (nn.Module): 共享专家网络,处理所有输入"""def __init__(self, args: ModelArgs):super().__init__()self.dim = args.dim# 分布式校验:确保总专家数能被GPU总数整除assert args.n_routed_experts % world_size == 0, f"专家数必须能被GPU数整除 (当前world_size={world_size})"# 专家数量配置self.n_routed_experts = args.n_routed_experts # 总专家数self.n_local_experts = args.n_routed_experts // world_size # 当前GPU管理的专家数self.n_activated_experts = args.n_activated_experts # 每个输入激活的专家数# 分布式专家索引范围(将专家分配到不同GPU)self.experts_start_idx = rank * self.n_local_experts # 当前GPU的起始专家编号self.experts_end_idx = self.experts_start_idx + self.n_local_experts # 结束专家编号# 门控网络(通常为线性层+softmax)self.gate = Gate(args)# 创建专家列表(仅在本地GPU实例化对应专家,其他位置为None节省内存)self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else Nonefor i in range(self.n_routed_experts)])# 共享专家(所有输入都会经过该模块,与路由专家互补)self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)def forward(self, x: torch.Tensor) -> torch.Tensor:# 展平输入:[batch, seq_len, dim] => [batch*seq_len, dim]shape = x.size()x = x.view(-1, self.dim)# 门控计算:获取top-k专家权重和索引weights, indices = self.gate(x) # weights形状 [N, k], indices形状 [N, k]# 初始化输出张量y = torch.zeros_like(x)# 统计每个专家被选中的次数(用于跳过未激活的专家)counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()# 仅处理当前GPU管理的专家范围for i in range(self.experts_start_idx, self.experts_end_idx):if counts[i] == 0: # 该专家未被任何输入选中continueexpert = self.experts[i] # 获取本地专家实例idx, top = torch.where(indices == i) # 找出选择该专家的输入位置# 加权求和:y[选中的输入] += 专家输出 * 对应权重y[idx] += expert(x[idx]) * weights[idx, top, None] # None用于广播维度# 共享专家处理所有输入z = self.shared_experts(x)# 分布式通信:聚合所有GPU上的专家输出if world_size > 1:dist.all_reduce(y) # 使用AllReduce同步各GPU计算结果# 合并路由专家和共享专家的结果,并恢复原始形状return (y + z).view(shape)
关键机制解析:
1, 分布式专家分配:
- n_routed_experts: 总专家数(需为GPU数的整数倍)
- n_local_experts: 当前GPU管理的专家数(=总专家数 / GPU数)
- experts列表: 只实例化当前GPU负责的专家,其他位置为None,极大减少显存占用
2. 双路径计算:
-
路由专家路径
通过门控动态选择top-k专家,返回每行top-k值得权重weights和专家索引indices。
门控计算示例:# 假设某输入的门控输出为[0.1, 0.5, 0.3, 0.1](softmax后每个专家的概率),如果topk=2: weights = [0.5, 0.3] indices = [1, 2] # 对应的专家编号
典型计算流程:
A[输入x] --> B(门控网络)B --> C{选择top-k专家}C --> D[专家1计算]C --> E[专家2计算]D --> F(加权求和)E --> F
-
共享专家路径
所有输入均经过shared_experts,作用包括:- 提供基础特征变换能力
- 缓解专家路由不均衡问题
- 增强模型鲁棒性(即使某些路由专家失效仍有保底处理)
3. 通信优化:
all_reduce同步:确保各GPU上的部分结果y能全局聚合
设计特点总结:
-
显存高效
通过分布式专家分配,单个GPU只需保存部分专家参数。假设总专家数=64,GPU数=8,则每个GPU仅维护8个专家,显存占用减少为1/8。 -
计算负载均衡
torch.bincount统计专家使用次数,跳过未激活专家,避免无效计算。 -
灵活扩展
调整n_routed_experts和n_activated_experts可平衡模型容量与计算量
共享专家提供保底处理能力,提升模型稳定性