当前位置: 首页 > news >正文

门控MLP(Qwen3MLP)与稀疏混合专家(Qwen3MoeSparseMoeBlock)模块解析

Qwen3MLP

Qwen3MLP是基于门控机制的MLP模块,采用了类似门控线性单元(GLU)的结构。它通过三个线性变换层(gate_proj、up_proj和down_proj)和SiLU激活函数,先将输入从隐藏维度扩展到中间维度,经过门控计算后再投影回原始维度。该模块保持了输入输出形状的一致性,演示了如何逐步执行前向传播并验证计算正确性,展示了Transformer模型中常用的前馈神经网络结构。
具体代码与测试如下:

import torch
import torch.nn as nn
from transformers.activations import ACT2FNclass Qwen3MLP(nn.Module):def __init__(self, config):super().__init__()self.config = configself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_sizeself.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)self.act_fn = ACT2FN[config.hidden_act] # siludef forward(self, x):down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))return down_proj# 模拟配置类
class MockConfig:def __init__(self):self.hidden_size = 1024self.intermediate_size = 2048self.hidden_act = "silu"# 完整示例
if __name__ == "__main__":# 1. 创建配置对象config = MockConfig()# 2. 初始化Qwen3MLP模块mlp = Qwen3MLP(config)# 3. 创建测试输入数据batch_size = 2seq_length = 8hidden_size = config.hidden_size  # 1024# 输入张量形状: (batch_size, seq_length, hidden_size)input_tensor = torch.randn(batch_size, seq_length, hidden_size)print("=== Qwen3MLP 示例 ===")print(f"配置信息:")print(f"  - hidden_size: {config.hidden_size}")print(f"  - intermediate_size: {config.intermediate_size}")print(f"  - activation: {config.hidden_act}")print(f"\n输入张量形状: {input_tensor.shape}")# 4. 前向传播with torch.no_grad():output_tensor = mlp(input_tensor)print(f"输出张量形状: {output_tensor.shape}")# 5. 验证输出形状与输入形状一致assert output_tensor.shape == input_tensor.shape, \f"输出形状 {output_tensor.shape} 与输入形状 {input_tensor.shape} 不一致"print("\n=== MLP 层内部组件 ===")print(f"gate_proj 权重形状: {mlp.gate_proj.weight.shape}")print(f"up_proj 权重形状: {mlp.up_proj.weight.shape}")print(f"down_proj 权重形状: {mlp.down_proj.weight.shape}")# 6. 逐步计算过程演示print("\n=== 前向传播步骤 ===")with torch.no_grad():# 第一步: 门控投影gate_output = mlp.gate_proj(input_tensor)print(f"1. gate_proj 输出形状: {gate_output.shape}")# 第二步: 激活函数gate_activated = mlp.act_fn(gate_output)print(f"2. 激活函数后形状: {gate_activated.shape}")# 第三步: 上投影up_output = mlp.up_proj(input_tensor)print(f"3. up_proj 输出形状: {up_output.shape}")# 第四步: 门控线性单元 (GLU)glu_output = gate_activated * up_outputprint(f"4. GLU 输出形状: {glu_output.shape}")# 第五步: 下投影final_output = mlp.down_proj(glu_output)print(f"5. down_proj 输出形状: {final_output.shape}")# 验证与直接调用forward的结果一致direct_output = mlp(input_tensor)assert torch.allclose(final_output, direct_output, atol=1e-6), "逐步计算结果与直接调用不一致"print("✓ 逐步计算结果与直接调用结果一致")print("\n=== 示例完成 ===")print(f"MLP 成功处理了形状为 {input_tensor.shape} 的输入,输出形状为 {output_tensor.shape}")
=== Qwen3MLP 示例 ===
配置信息:- hidden_size: 1024- intermediate_size: 2048- activation: silu输入张量形状: torch.Size([2, 8, 1024])
输出张量形状: torch.Size([2, 8, 1024])=== MLP 层内部组件 ===
gate_proj 权重形状: torch.Size([2048, 1024])
up_proj 权重形状: torch.Size([2048, 1024])
down_proj 权重形状: torch.Size([1024, 2048])=== 前向传播步骤 ===
1. gate_proj 输出形状: torch.Size([2, 8, 2048])
2. 激活函数后形状: torch.Size([2, 8, 2048])
3. up_proj 输出形状: torch.Size([2, 8, 2048])
4. GLU 输出形状: torch.Size([2, 8, 2048])
5. down_proj 输出形状: torch.Size([2, 8, 1024])
✓ 逐步计算结果与直接调用结果一致=== 示例完成 ===
MLP 成功处理了形状为 torch.Size([2, 8, 1024]) 的输入,输出形状为 torch.Size([2, 8, 1024])

Qwen3MoeSparseMoeBlock

Qwen3 模型的稀疏混合专家(Sparse MoE)模块,核心是通过“路由机制+多专家并行计算”提升模型在大参数量下的效率与能力。

Qwen3MoeSparseMoeBlock 处理输入的流程可分为 路由计算→专家选择→并行计算→结果聚合 四步:

1. 路由计算:为每个 token 选专家
  • 输入 hidden_states(形状 [batch_size, seq_length, hidden_size])先展平为 [batch*seq, hidden_size]
  • self.gate(线性层)生成 router_logits(每个 token 对 8 个专家的“匹配分数”);
  • 通过 softmax+topk,为每个 token 选 num_experts_per_tok=2 个“最匹配专家”,并得到归一化的路由权重(决定每个专家对 token 的贡献占比)。
2. 专家选择:标记活跃专家

通过 one_hot 编码生成 expert_mask,标记“哪些专家被哪些 token 选中”;再通过 expert_hit 筛选出至少被一个 token 选中的活跃专家(示例中 8 个专家都有 token 命中)。

3. 并行计算:专家各自处理 token

对每个活跃专家,执行:

  • 筛选出“属于当前专家”的 token(通过 expert_mask 定位);
  • 调用该专家的 Qwen3MoeMLP 层(结构同普通 MLP,但参数量仅服务部分 token),完成“门控投影→激活→上投影→下投影”的计算;
  • 用路由权重对专家输出加权(确保不同专家的贡献按匹配度分配)。
4. 结果聚合:合并所有专家输出

通过 index_add_ 将每个专家处理后的 token 结果,按原始位置合并,最终还原为 [batch_size, seq_length, hidden_size] 的输出。


具体代码与测试如下:

import torch.nn as nn
from transformers.activations import ACT2FN
import torch.nn.functional as Fclass Qwen3MoeMLP(nn.Module):def __init__(self, config, intermediate_size=None):super().__init__()self.config = configself.hidden_size = config.hidden_size  # 512self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size# 256self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# 512, 256self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) # 512, 256self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) # 256, 512self.act_fn = ACT2FN[config.hidden_act]def forward(self, x):down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))return down_projclass Qwen3MoeSparseMoeBlock(nn.Module):def __init__(self, config):super().__init__()self.num_experts = config.num_experts # 8self.top_k = config.num_experts_per_tok # 2self.norm_topk_prob = config.norm_topk_prob # Trueself.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) # 512 -> 8self.experts = nn.ModuleList([Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]) #  512 -> 256 -> 512def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape  # 2, 6, 512hidden_states = hidden_states.view(-1, hidden_dim) # 2, 6, 512 -> 12, 512router_logits = self.gate(hidden_states) # 12 8routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # 12 8routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # 12 2if self.norm_topk_prob:  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)routing_weights = routing_weights.to(hidden_states.dtype)final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device) # 12 512expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)# 12 2 8    8 2 12 print("expert_mask: \n",expert_mask)expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() # 8print("expert hit: \n",expert_hit)for expert_idx in expert_hit:expert_layer = self.experts[expert_idx]  # Qwen3MoeMLPidx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) # 4 4 if expert_idx == 0:print("expert_mask[expert_idx].squeeze(0):",expert_mask[expert_idx].squeeze(0))print("idx:",idx)print("top_x:",top_x)print("hidden_states[None, top_x]:",hidden_states[None, top_x].shape)current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)   # 1, 4, 512 -> 4, 512current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]# 4, 512 * 4, 512 -> 4, 512final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) # 2, 6, 512return final_hidden_states, router_logits
class MockConfig:def __init__(self):self.hidden_size = 512self.moe_intermediate_size = 256self.hidden_act = "silu"self.num_experts = 8self.num_experts_per_tok = 2self.norm_topk_prob = Trueimport numpy as np
import random# 设置随机种子以确保可重复性
def set_random_seed(seed=42):"""设置所有随机种子以确保结果可重复"""torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False
# 完整示例
if __name__ == "__main__":set_random_seed(42)config = MockConfig()moe_block = Qwen3MoeSparseMoeBlock(config)batch_size = 2seq_length = 6hidden_size = config.hidden_size  # 512input_tensor = torch.randn(batch_size, seq_length, hidden_size)print("=== Qwen3MoeSparseMoeBlock 示例 ===")print(f"配置信息:")print(f"  - hidden_size: {config.hidden_size}")print(f"  - moe_intermediate_size: {config.moe_intermediate_size}")print(f"  - activation: {config.hidden_act}")print(f"  - num_experts: {config.num_experts}")print(f"  - num_experts_per_tok: {config.num_experts_per_tok}")print(f"  - norm_topk_prob: {config.norm_topk_prob}")print(f"\n输入张量形状: {input_tensor.shape}")with torch.no_grad():output_tensor, router_logits = moe_block(input_tensor)print(f"输出张量形状: {output_tensor.shape}")print(f"路由逻辑形状: {router_logits.shape}")
=== Qwen3MoeSparseMoeBlock 示例 ===
配置信息:- hidden_size: 512- moe_intermediate_size: 256- activation: silu- num_experts: 8- num_experts_per_tok: 2- norm_topk_prob: True输入张量形状: torch.Size([2, 6, 512])
expert_mask: tensor([[[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],[[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],[[0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1],[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]],[[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],[0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]]])
expert hit: tensor([[0],[1],[2],[3],[4],[5],[6],[7]])
expert_mask[expert_idx].squeeze(0): tensor([[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]])
idx: tensor([0, 0, 0, 1])
top_x: tensor([ 0,  2, 10,  6])
hidden_states[None, top_x]: torch.Size([1, 4, 512])
输出张量形状: torch.Size([2, 6, 512])
路由逻辑形状: torch.Size([12, 8])

文章转载自:

http://IrickwH6.rwzkp.cn
http://f98Jb8h3.rwzkp.cn
http://R3bGlWO4.rwzkp.cn
http://hiu9VyhK.rwzkp.cn
http://QuTzcFrI.rwzkp.cn
http://eBs35SZd.rwzkp.cn
http://RMZHrkeb.rwzkp.cn
http://oALxJl07.rwzkp.cn
http://8TWyb3B4.rwzkp.cn
http://HoRnWSVq.rwzkp.cn
http://GBSrgV1u.rwzkp.cn
http://X9Ced10q.rwzkp.cn
http://86NC0u7w.rwzkp.cn
http://ikeIQxmO.rwzkp.cn
http://dB3mS0Vm.rwzkp.cn
http://jlagqp7L.rwzkp.cn
http://4ylBjehs.rwzkp.cn
http://8OQScHJy.rwzkp.cn
http://XyBYIl5F.rwzkp.cn
http://7aPv0kbw.rwzkp.cn
http://PLmQpmSX.rwzkp.cn
http://loXi00Dx.rwzkp.cn
http://8HiXU08U.rwzkp.cn
http://i1QizKSZ.rwzkp.cn
http://SVwtv1vu.rwzkp.cn
http://SXwwxUCx.rwzkp.cn
http://JNXB4Mjr.rwzkp.cn
http://YLIfK7b7.rwzkp.cn
http://kOyKAYwF.rwzkp.cn
http://HfGwZ5AM.rwzkp.cn
http://www.dtcms.com/a/368759.html

相关文章:

  • React Hooks useContext
  • 【Linux】Linux 的 cp -a 命令的作用
  • 基于FPGA实现CRC校验码算法(以MODBUS中校验码要求为例)verilog代码+仿真验证
  • LeetCode刷题-top100( 矩阵置零)
  • 算法模板(Java版)_DFS与BFS
  • 一分钟了解Modbus 转 IEC61850 网关
  • Webpack 有哪些特性?构建速度?如何优化?
  • 2025精选5款AI视频转文字工具,高效转录秒变文字!
  • 【最新版】发烧级完美解码播放器PureCodec v2025.08.29 中文免费版_电脑播放器影音解码包
  • 阿里云国际代理:阿里云的云数据库是什么?
  • 盲盒抽卡机小程序功能版块设计的合理性评估维度
  • Memory write error at 0x100000. MMU page translation fault
  • 纯血鸿蒙开发入门:2.展示hello world
  • 【1】策略模式 + 模板方法模式的联合应用
  • 突发奇想,还未实践,在Vben5的Antd模式下,将表单从「JS 配置化」改写成「模板可视化」形式(豆包版)
  • Flash Attention:突破大模型推理内存瓶颈的革命性算法
  • 【正则表达式】 正则表达式的分组和引用
  • 具身智能的工程落地:视频-控制闭环的实践路径
  • E+H音叉开关FTL31-AA4M2AAWBJ
  • Android 权限机制默认授权分析
  • 深入理解 HarmonyOS Stage 模型与 UIAbility 生命周期管理
  • Vue3中的数据响应【4】
  • 因泰立科技:用激光雷达重塑智能工厂物流生态
  • 【Windows】通过 runas 命令实现多用户权限测试的完整流程
  • LangChain实战(十六):构建基于SQL数据库的数据分析Agent
  • Struts2 工作总结
  • 软件设计模式之单例模式
  • 小迪安全v2023学习笔记(七十八讲)—— 数据库安全RedisCouchDBH2database未授权CVE
  • 【Go】P2 Golang 常量与变量
  • Leetcode—721. 账户合并【中等】