
Learning Cutting-Edge LLM Architecture from Qwen3-Next

Official blog post: https://zhuanlan.zhihu.com/p/1949631642294522105

Source code: https://github.com/huggingface/transformers/tree/main/src/transformers/models/qwen3_next

[Figure: Qwen3-Next overall architecture]

Overview:
Qwen3-Next is still a Transformer decoder: pre-norm, attention, norm, FFN. The most distinctive change is that 75% of the layers use Gated DeltaNet while the rest keep standard attention, hence the 3-to-1 ratio in the figure (three Gated DeltaNet blocks for every standard-attention block), as sketched below.
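As a minimal sketch of that layout (the `layer_types` name and the 48-layer count are illustrative, not the actual config API):

```python
# Illustrative 3:1 layer pattern: every 4th layer is standard attention,
# the other three are Gated DeltaNet (names and counts are assumptions).
num_layers = 48
layer_types = [
    "full_attention" if (i + 1) % 4 == 0 else "linear_attention"
    for i in range(num_layers)
]
assert layer_types.count("linear_attention") == 3 * layer_types.count("full_attention")
```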

Let's walk through the changes from bottom to top.

1. Zero-centered RMSNorm

Per the original paper, RMSNorm (root-mean-square normalization) removes the mean computation of layer normalization and rescales by the RMS alone.

The formula:
$$
\hat{x}_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}}
$$

A standard, truly zero-centered RMSNorm would subtract the mean first (which makes it essentially LayerNorm's normalization step), like this:

```python
# A truly zero-centered _norm: subtract the mean, then normalize
# by the centered variance.
def _norm(self, x):
    mu = x.mean(-1, keepdim=True)
    x_centered = x - mu
    variance = x_centered.pow(2).mean(-1, keepdim=True)
    return x_centered / torch.sqrt(variance + self.eps)
```

Qwen3-Next's implementation, however, is still a standard (non-centered) RMSNorm. The difference is purely in parameterization: the scale weight is initialized to zeros rather than ones, and the output is computed as `x_norm * (1 + weight)`, so the effective scale factor still starts at 1 at the beginning of training. It is an initialization strategy whose design goals are training stability and numerical-precision control (note the normalization itself runs in float32).

The code:

```python
import torch
from torch import nn

class Qwen3NextRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # zero-initialized, vs. ones in standard RMSNorm

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        output = output * (1.0 + self.weight.float())  # effective scale starts at 1
        return output.type_as(x)
```
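A quick sanity check of that claim, using the class above (a minimal sketch):

```python
# With zero-initialized weight, the module reduces to plain RMSNorm at init.
import torch

norm = Qwen3NextRMSNorm(dim=8)
x = torch.randn(2, 8)
plain_rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps)
assert torch.allclose(norm(x), plain_rms, atol=1e-6)
```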

2. Gated DeltaNet

Paper: https://arxiv.org/pdf/2412.06464

[Figure: Gated DeltaNet overview, from the paper]

The update rule (Gated DeltaNet, per the paper):

$$
S_t = S_{t-1}\left(\alpha_t\left(I - \beta_t\, k_t k_t^{\top}\right)\right) + \beta_t\, v_t k_t^{\top}, \qquad o_t = S_t\, q_t
$$
Explanation:
S is the memory (the state matrix) and t is the time step; α is the gated decay coefficient, with values in (0, 1); β is the update-strength coefficient, also in (0, 1), controlling how strongly new information is written.
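As a sketch for intuition, here is a naive per-step reference implementation of this recurrence in plain PyTorch (illustrative only; the production kernels use a chunked parallel formulation, and the key L2-normalization mirrors the `use_qk_l2norm_in_kernel=True` flag in the code further below):

```python
# Naive per-step Gated DeltaNet recurrence (reference sketch, single head).
import torch

def gated_delta_rule_naive(q, k, v, alpha, beta):
    """q, k: (T, d_k); v: (T, d_v); alpha, beta: (T,), values in (0, 1)."""
    d_k, d_v = k.shape[-1], v.shape[-1]
    S = torch.zeros(d_v, d_k)                      # memory state S_0
    outs = []
    for t in range(k.shape[0]):
        k_t = k[t] / k[t].norm().clamp(min=1e-6)   # L2-normalize the key
        # decay by alpha_t, erase the component along k_t, then write beta_t * v_t k_t^T
        S = alpha[t] * (S - beta[t] * torch.outer(S @ k_t, k_t)) + beta[t] * torch.outer(v[t], k_t)
        outs.append(S @ q[t])                      # o_t = S_t q_t
    return torch.stack(outs)                       # (T, d_v)
```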

What it improves:

Gated DeltaNet combines the strengths of Mamba2 and DeltaNet: Mamba2's gate decays all memory indiscriminately (one-size-fits-all erasure), while DeltaNet erases precisely but cannot clear large amounts of memory quickly. The comparison below makes this concrete.
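For reference, the three update rules side by side (paraphrased from the paper; Mamba2 shown in its simplest scalar-decay form):

$$
\begin{aligned}
\text{Mamba2:} \quad & S_t = \alpha_t\, S_{t-1} + v_t k_t^{\top} && \text{(global decay, no targeted erasure)}\\
\text{DeltaNet:} \quad & S_t = S_{t-1}\left(I - \beta_t\, k_t k_t^{\top}\right) + \beta_t\, v_t k_t^{\top} && \text{(targeted erasure, no global decay)}\\
\text{Gated DeltaNet:} \quad & S_t = S_{t-1}\left(\alpha_t\left(I - \beta_t\, k_t k_t^{\top}\right)\right) + \beta_t\, v_t k_t^{\top} && \text{(both)}
\end{aligned}
$$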

How this differs from standard attention and other linear-attention variants:

| Dimension | Standard attention (Transformer) | Linear attention (Mamba2 / DeltaNet) | Gated DeltaNet |
| --- | --- | --- | --- |
| Compute complexity | $O(L^2 \cdot d_k)$ (quadratic, slow) | $O(L \cdot d_k d_v)$ (linear, fast) | $O(L \cdot d_k d_v)$ (linear, fast) |
| Memory clearing | No active clearing; relies on softmax weighting | Mamba2: global decay (indiscriminate); DeltaNet: precise deletion (slow to clear) | Gated + precise (fast and accurate) |
| Long-context ability | Weak (intractable at large $L$) | Strong (handles large $L$) but with quality gaps | Strong (large $L$ with good quality) |
| Parallel training efficiency | Medium (intra-block parallelism) | Medium (chunked parallelism + algebraic decomposition) | Good (chunked parallelism + simplified compute) |
| Typical use cases | Short text, precision-critical tasks (translation, summarization) | Long text, efficiency-first tasks (log analysis) | Long text + high precision (long-report QA, code understanding) |
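The memory row can be made concrete with a back-of-envelope count (sizes hypothetical):

```python
# Back-of-envelope check of the memory scaling claim (hypothetical sizes, single head).
L, d_k, d_v = 32_768, 128, 256
kv_cache = L * (d_k + d_v)   # standard attention: KV cache grows with sequence length L
state = d_k * d_v            # linear attention: fixed-size recurrent state
print(kv_cache, state)       # 12582912 vs 32768 floats
```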

Network structure

```text
GatedDeltaNet(
  (silu): SiLU()
  (q_proj): Linear(in_features=512, out_features=1024, bias=False)
  (k_proj): Linear(in_features=512, out_features=1024, bias=False)
  (v_proj): Linear(in_features=512, out_features=2048, bias=False)
  (b_proj): Linear(in_features=512, out_features=4, bias=False)
  (a_proj): Linear(in_features=512, out_features=4, bias=False)
  (q_conv1d): ShortConvolution(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024, bias=False, activation=silu)
  (k_conv1d): ShortConvolution(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024, bias=False, activation=silu)
  (v_conv1d): ShortConvolution(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048, bias=False, activation=silu)
  (g_proj): Linear(in_features=512, out_features=2048, bias=False)
  (o_norm): FusedRMSNormSwishGate(512, eps=1e-05)
  (o_proj): Linear(in_features=2048, out_features=512, bias=False)
)
```
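Reading the printed sizes (this appears to be a small test configuration; also note the print shows separate q/k/v short convolutions, whereas the Hugging Face implementation below fuses them into one grouped conv over the concatenated QKV). A hedged sketch of the implied bookkeeping; only the products are visible in the print, so the per-head split is an assumption:

```python
# Bookkeeping implied by the printed sizes (assumptions marked; small config).
hidden_size = 512
key_dim = 1024                          # q_proj / k_proj out_features
value_dim = 2048                        # v_proj / g_proj out_features
num_v_heads = 4                         # b_proj / a_proj out_features: one beta and one decay input per value head
head_v_dim = value_dim // num_v_heads   # 512, matching the o_norm width in the print
```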

Code implementation

Code walkthrough: https://www.doubao.com/thread/w7e2fecc6eebc2029

```python
class Qwen3NextGatedDeltaNet(nn.Module):
    def __init__(self, config: Qwen3NextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps

        # QKV
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=self.conv_kernel_size - 1,
        )

        # projection of the input hidden states
        projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
        projection_size_ba = self.num_v_heads * 2
        self.in_proj_qkvz = nn.Linear(self.hidden_size, projection_size_qkvz, bias=False)
        self.in_proj_ba = nn.Linear(self.hidden_size, projection_size_ba, bias=False)

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))

        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))
        self.norm = (
            Qwen3NextRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
            if FusedRMSNormGated is None
            else FusedRMSNormGated(
                self.head_v_dim,
                eps=self.layer_norm_epsilon,
                activation=self.activation,
                device=torch.cuda.current_device(),
                dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
            )
        )
        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

        self.causal_conv1d_fn = causal_conv1d_fn
        self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
        self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
        self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of the required library is not installed. Falling back to "
                "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

    def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
        """Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`."""
        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
            self.num_k_heads,
            2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads,
        )
        new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads)

        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
        split_arg_list_qkvz = [
            self.head_k_dim,
            self.head_k_dim,
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
        ]
        split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads]
        query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3)
        b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3)
        # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
        value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim)
        z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim)
        b = b.reshape(b.size(0), b.size(1), self.num_v_heads)
        a = a.reshape(a.size(0), a.size(1), self.num_v_heads)
        return query, key, value, z, b, a

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[Qwen3NextDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape

        use_precomputed_states = (
            cache_params is not None
            and cache_params.has_previous_state
            and seq_len == 1
            and cache_position is not None
        )

        # getting projected states from cache if it exists
        if cache_params is not None:
            conv_state = cache_params.conv_states[self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]

        projected_states_qkvz = self.in_proj_qkvz(hidden_states)
        projected_states_ba = self.in_proj_ba(hidden_states)
        query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
        query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))

        mixed_qkv = torch.cat((query, key, value), dim=-1)
        mixed_qkv = mixed_qkv.transpose(1, 2)

        if use_precomputed_states:
            # 2. Convolution sequence transformation
            # NOTE: the conv state is updated in `causal_conv1d_update`
            mixed_qkv = self.causal_conv1d_update(
                mixed_qkv,
                conv_state,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )
        else:
            if cache_params is not None:
                conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
                cache_params.conv_states[self.layer_idx] = conv_state

            if self.causal_conv1d_fn is not None:
                mixed_qkv = self.causal_conv1d_fn(
                    x=mixed_qkv,
                    weight=self.conv1d.weight.squeeze(1),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=None,
                )
            else:
                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

        mixed_qkv = mixed_qkv.transpose(1, 2)
        query, key, value = torch.split(
            mixed_qkv,
            [self.key_dim, self.key_dim, self.value_dim],
            dim=-1,
        )
        query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
        key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
        value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)

        beta = b.sigmoid()
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        if self.num_v_heads // self.num_k_heads > 1:
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

        if not use_precomputed_states:
            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=None,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )
        else:
            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )

        # Update cache
        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state

        z_shape_og = z.shape
        # reshape input data into 2D tensor
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)

        output = self.out_proj(core_attn_out)
        return output
```
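A minimal smoke-test sketch, assuming a transformers version that ships this module (the import paths and default config values are assumptions and may differ; without the flash-linear-attention / causal-conv1d fast path it falls back to the torch implementations):

```python
# Hypothetical smoke test; paths and defaults may differ across versions.
import torch
from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet

config = Qwen3NextConfig()
layer = Qwen3NextGatedDeltaNet(config, layer_idx=0)
x = torch.randn(1, 16, config.hidden_size)
out = layer(x)    # no cache passed: takes the chunked (training/prefill) path
print(out.shape)  # torch.Size([1, 16, config.hidden_size])
```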

3. Gated Attention

Key features:

- A sigmoid gate on top of multi-head attention: the q projection outputs an extra half that gates the attention output (see the sketch after the code below);
- Per-head normalization of the q and k vectors (QK-norm);
- GQA (grouped-query attention) is supported.

Code implementation:

```python
class Qwen3NextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3NextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states, gate = torch.chunk(
            self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
        )
        gate = gate.reshape(*input_shape, -1)

        query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = attn_output * torch.sigmoid(gate)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
```
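The novelty is concentrated in two lines of the forward pass above: `q_proj` produces twice the per-head width, and the extra half becomes a per-channel sigmoid gate applied to the attention output. An isolated sketch of just that step (shapes illustrative):

```python
# Isolated sketch of the sigmoid output gate (shapes illustrative).
import torch

B, T, H, D = 2, 8, 4, 64
q_and_gate = torch.randn(B, T, H, 2 * D)       # fused q_proj output per head
q, gate = torch.chunk(q_and_gate, 2, dim=-1)   # query half, gate half
attn_out = torch.randn(B, T, H * D)            # stand-in for the attention result
gated = attn_out * torch.sigmoid(gate.reshape(B, T, H * D))
```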

4. MTP (Multi-Token Prediction)

First adopted at scale by DeepSeek: one forward pass predicts multiple future tokens rather than just the next one.

https://zhuanlan.zhihu.com/p/15037286337

https://medium.com/@bingqian/understanding-multi-token-prediction-mtp-in-deepseek-v3-ed634810c290
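
A toy sketch of the idea (illustrative only; not DeepSeek's or Qwen3-Next's actual MTP module): auxiliary heads predict the token k steps ahead from the same trunk hidden states, and their losses are averaged into the training objective.

```python
# Toy multi-token-prediction heads (illustrative; real MTP modules differ).
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyMTPHeads(nn.Module):
    def __init__(self, hidden: int, vocab: int, n_future: int = 2):
        super().__init__()
        self.heads = nn.ModuleList(nn.Linear(hidden, vocab) for _ in range(n_future))

    def forward(self, h: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # h: (B, T, hidden) trunk states; targets: (B, T) token ids
        loss = h.new_zeros(())
        for k, head in enumerate(self.heads, start=1):
            logits = head(h[:, :-k])          # state at position t predicts token t+k
            loss = loss + F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets[:, k:].reshape(-1)
            )
        return loss / len(self.heads)

mtp = ToyMTPHeads(hidden=64, vocab=1000)
print(mtp(torch.randn(2, 16, 64), torch.randint(0, 1000, (2, 16))))
```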


