当前位置: 首页 > news >正文

nanochat 三 模型结构详解

一、模型整体架构
nanochat 模型基于 Transformer 架构构建,核心采用 decoder-only 结构,包含嵌入层、Transformer 编码器块序列、输出层以及旋转位置编码相关的组件,具体结构如下:

class GPT(nn.Module):def __init__(self, config):super().__init__()self.config = config# 核心Transformer结构:包含嵌入层和注意力块列表self.transformer = nn.ModuleDict({# 词嵌入层:将token映射为嵌入向量"wte": nn.Embedding(config.vocab_size, config.n_embd),# 注意力块序列:n_layer个Block串联"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in 
range(config.n_layer)]),})# 输出层:将嵌入向量映射回词表维度self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)# 旋转位置编码相关参数self.rotary_seq_len = config.sequence_len * 10self.head_dim = config.n_embd // config.n_head

从上述代码可见,模型的核心计算流为:输入token → 词嵌入层(wte)→ 多层Block处理 → 输出层(lm_head)→ 词表概率分布。其中,Block 作为基本计算单元,集成了注意力机制和多层感知机,是模型能力的核心载体。

二、核心组件详解
2.1 基础计算单元:Block
每个 Block 包含“因果自注意力层(CausalSelfAttention)”和“多层感知机(MLP)”两个核心模块,其结构逻辑如下:

class Block(nn.Module):def __init__(self, config, layer_idx):super().__init__()# 因果自注意力层:捕捉序列内依赖关系self.attn = CausalSelfAttention(config, layer_idx)# 多层感知机:增强模型非线性表达能力self.mlp = MLP(config)def forward(self, x, cos_sin, kv_cache):# 注意力层+残差连接x = x + self.attn(norm(x), cos_sin, kv_cache)# MLP层+残差连接x = x + self.mlp(norm(x))return x

前向传播过程中,输入向量 x 分别经过注意力层和 MLP 层,且每一层的输出都与输入进行残差相加,有效避免梯度消失问题。同时,输入的 cos_sin 为旋转位置编码参数,kv_cache 为推理时的键值缓存,用于提升解码效率。

2.2 核心能力核心:因果自注意力机制(CausalSelfAttention)
因果自注意力机制是模型实现“预测下一个token”任务的关键,它确保每个位置的预测仅依赖于前面的token(因果约束),同时通过多头注意力(Multi-Head Attention)和分组查询注意力(GQA)提升建模能力与效率。

class CausalSelfAttention(nn.Module):def __init__(self, config, layer_idx):super().__init__()self.layer_idx = layer_idxself.n_head = config.n_head  # 查询头数量self.n_kv_head = config.n_kv_head  # 键值头数量(GQA用)self.n_embd = config.n_embdself.head_dim = self.n_embd // self.n_head  # 每个头的维度# 合法性校验assert self.n_embd % self.n_head == 0assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0# 投影层:将输入映射为Q、K、Vself.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)# 输出投影层:整合多头结果self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

支持 GQA(分组查询注意力):通过减少键值(K/V)头的数量(n_kv_head ≤ n_head),在保证建模能力的同时降低计算和内存开销,平衡了性能与效率。

2.3 前向传播逻辑
前向传播过程可分为“投影与位置编码”“注意力计算”“结果整合”三个阶段,关键步骤如下:

def forward(self, x, cos_sin, kv_cache):B, T, C = x.size()  # B:批次大小, T:序列长度, C:嵌入维度# 1. 投影为Q、K、V并调整形状为(B, T, H, D)q = self.c_q(x).view(B, T, self.n_head, self.head_dim)k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)# 2. 应用旋转位置编码(增强位置敏感性)cos, sin = cos_sinq, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)q, k = norm(q), norm(k)  # 归一化防止数值爆炸# 3. 调整维度为(B, H, T, D)以适配注意力计算q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)# 4. KV缓存管理(推理时复用历史K/V,提升效率)if kv_cache is not None:k, v = kv_cache.insert_kv(self.layer_idx, k, v)# 5. 注意力计算(根据场景选择因果掩码)enable_gqa = self.n_head != self.n_kv_head  # 判断是否启用GQAif kv_cache is None or Tq == Tk:  # 训练或序列长度匹配时y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)elif Tq == 1:  # 推理单token时(全注意力)正常推理场景y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)else:  # 推理多token块时(混合掩码)attn_mask = self._build_attn_mask(Tq, Tk, q.device)y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)# 6. 整合多头结果并投影输出y = y.transpose(1, 2).contiguous().view(B, T, -1)y = self.c_proj(y)return y

1)当if kv_cache is None or Tq == Tk 是首次推理或者训练情况,训练场景将本质串行的自回归任务转化为并行计算,使海量数据训练成为可能。并行化分为“序列内并行”和“批次间并行”两个层次。

逻辑说明:例如模型输入最大长度1024,就取1024训练数据输入。
序列并行计算:当这个输入序列被送入模型时,模型内部的 Transformer 结构会一次性为这 1023 个位置并行地计算出各自的 Logits。
* 位置 1 的 Logits:基于 [T_1] 预测 T_2
* 位置 2 的 Logits:基于 [T_1, T_2] 预测 T_3
* …
* 位置 1023 的 Logits:基于 [T_1, ..., T_{1023}] 预测 T_{1024}

  • 损失计算:计算这 1023 个预测的交叉熵损失,然后取平均,得到这一个序列的损失。
    关键点:虽然是并行计算,但因果注意力机制确保了每个位置的预测都只依赖于它前面的 token,完美模拟了“预测下一个词”这个任务,同时享受了并行计算的巨大速度优势。

批次之间的并行

  • GPU 的显存足够大,不会一次只处理一个 1024 长度的序列,可以同时处理很多个这样的序列。
  • 实际操作
    • 假设我们的 Batch Size(批次大小)是 64。
    • 我们会从数据集中同时取出 64 个独立的、长度为 1024 的序列。
    • 我们把这 64 个序列打包成一个三维张量,形状是 [64, 1023](64 个序列,每个序列 1023 个输入 token)。
    • 这个巨大的张量被一次性送入 GPU。
    • GPU 的数千个计算核心会高度并行地处理这 64 个序列中的所有 token。
  • 损失计算
    • 模型会为这个批次输出一个形状为 [64, 1023, Vocab_Size] 的 Logits 张量(Vocab_Size 是词表大小)。
    • 我们会计算这 64 个序列中,每一个 token 预测的损失。
    • 总共有 64 * 1023 = 65472 个独立的预测任务。
    • 最终,我们将这 65472 个损失值全部加起来,再取平均,得到这个批次的最终损失值。
      这个最终的损失值,代表了模型在当前这 64 个样本上的平均表现。反向传播和参数更新就是基于这个“平均损失”来进行的。

2)正常推理场景会使用kv_cache来做单个token预测
在标准的自回归解码中,模型有一个“记忆”,叫做 KV Cache。在 Transformer 的自注意力机制中,为了计算当前 token 的注意力,需要知道它前面所有 token 的 Key (K) 和 Value (V) 矩阵。为了避免在每生成一个新 token 时都重新计算前面所有 token 的 K 和 V(这会非常慢),我们会把已经计算过的 K 和 V 缓存起来。

3)推理中的Chunk Decoding(分块解码)
是一种在标准自回归解码和完全并行解码之间的折中方案,它主要为了解决长序列推理时的内存瓶颈问题。

Chunk Decoding 的核心思想是:不必记住每一个历史 token,把历史“打包”成几个“摘要块”,只保留摘要信息,从而节省内存。它把生成长序列的过程,分成了多个“块”来处理。
Chunk Decoding 的工作流程:
设定 chunk_ size=4 ,初始阶段(和标准解码一样):

  • 输入 Prompt,模型生成第一个 token t_1 。缓存 t_1 的 KV。
  • 输入 [t_1] ,生成 t_2 。缓存 t_1, t _2 的 KV。
  • 输入 [t_1, t_2] ,生成 t_3 。缓存 t_1, t _2, t_3 的 KV。
  • 输入 [t_1, t_2, t_3] ,生成 t_4 。缓存 t_1, t _2, t_3, t_4 的 KV。
  • KV Cache 的大小是 4,和标准解码没区别。现在进入第一个 Chunk,现在要生成 t_5 。此时,KV Cache 里有 [t_1, t_2, t_3, t_4] 的 KV。

关键操作:我们把 [t_1, t_2, t_3, t_4] 这个历史序列,作为一个“摘要块”来处理。用一个特殊的“总结”操作(比如用一个小型的注意力池化层)来压缩这 4 个 token 的 KV,生成一个更短的、信息密度更高的“摘要KV”。
内存释放:我们丢弃 [t_1, t_2, t_3, t_4] 原始的、占内存的 KV,只保留那个压缩后的“摘要KV”。

  • 生成 t_5 :现在,模型在计算 t_5 的注意力时,它的上下文变成了 [摘要KV] ,而不是 [t_1, t_2, t_3, t_4] 。它基于这个摘要和当前状态生成 t_5 ,并缓存 t_5 的 KV。
  • 生成 t_6 ,缓存 [摘要KV, t_5] 。
  • 生成 t_7 ,缓存 [摘要KV, t_5, t_6] 。
  • 生成 t_8 ,缓存 [摘要KV, t_5, t_6, t_7] 。
  • 进入第二个 Chunk,又攒够了 4 个新的 token( t_5 到 t_8 ,注意 t_4 已经被摘要了)。我们再次对 [t_5, t_6, t_7, t_8] 进行压缩,生成第二个“摘要KV”。丢弃 t_5 到 t_8 的原始 KV。现在的 KV Cache 变成了 [摘要KV_1, 摘要KV_2] 。这个过程不断重复,无论序列生成多长,KV Cache 的大小始终被控制在 chunk_size 的水平(或者加上几个摘要块的大小),而不会无限增长。

优点
节省显存,支持超长序列生成:这是它最大的优点。它使得在消费级显卡上生成长达数万甚至数十万 token 的文本成为可能。
速度可能更快:在生成长序列时,由于每次自注意力的计算量(与 KV Cache 的长度成正比)被限制了,解码速度可能会比标准解码更快。
缺点
牺牲精度,可能导致质量下降:这是它最大的代价。将历史信息压缩成摘要,必然会丢失细节。这可能导致模型“忘记”了文章开头的一些关键信息,使得长文本的连贯性和一致性下降。比如,文章开头提到主角叫“张三”,生成到后面可能会忘记,改叫“李四”。
实现复杂度更高:它需要修改模型的解码逻辑,实现摘要生成和 KV Cache 的管理,比标准解码要复杂。

应用场景
Chunk Decoding 并不是通用解决方案,它主要用在那些对生成长度有极高要求,但对完美细节和一致性要求相对宽松的场景。
长文档摘要:输入一本书,输出一份摘要。
代码生成/补全:生成长篇的代码文件。
长篇故事/报告写作:需要生成几万字的初稿。

2.4 非线性增强模块:多层感知机(MLP)
MLP 模块用于对注意力层的输出进行非线性变换,增强模型的表达能力。其结构采用“输入扩展-非线性激活-输出压缩”的经典设计,具体如下:

class MLP(nn.Module):def __init__(self, config):super().__init__()# 输入扩展层:将嵌入维度放大4倍self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)# 输出压缩层:将维度还原为嵌入维度self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)def forward(self, x):x = self.c_fc(x)x = F.relu(x).square()  # 非线性激活:ReLU+平方x = self.c_proj(x)return x

关于嵌入维度放大倍数(4倍)的设计:扩展倍数越大,模型表达能力越强,但参数量和计算量也随之增加,易导致过拟合;倍数越小则更轻量,但拟合能力可能不足。实践表明,4倍是性能与计算开销的平衡选择,为行业通用经验值。

2.5 位置信息建模:旋转位置编码(Rotary Embedding)
采用旋转位置编码为token注入位置信息,通过对查询(Q)和键(K)的旋转操作,使模型捕捉序列的顺序依赖关系。预计算不同序列长度和头维度下的旋转参数(cos/sin),避免重复计算,提升效率:

def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):device = device or self.transformer.wte.weight.device# 计算频率:每两个维度为一组channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)inv_freq = 1.0 / (base ** (channel_range / head_dim))# 时间步长与频率矩阵计算t = torch.arange(seq_len, dtype=torch.float32, device=device)freqs = torch.outer(t, inv_freq)cos, sin = freqs.cos(), freqs.sin()cos, sin = cos.bfloat16(), sin.bfloat16()# 增加batch和head维度以支持广播cos, sin = cos[None, :, None, :], sin[None, :, None, :]return cos, sin

2.6 编码应用

def apply_rotary_emb(x, cos, sin):assert x.ndim == 4  # 要求输入为(B, T, H, D)四维张量d = x.shape[3] // 2  # 将头维度分为两半x1, x2 = x[..., :d], x[..., d:]  # 按最后一维分割# 旋转计算y1 = x1 * cos + x2 * siny2 = x1 * (-sin) + x2 * cosout = torch.cat([y1, y2], 3)  # 拼接还原头维度out = out.to(x.dtype)return out
http://www.dtcms.com/a/615561.html

相关文章:

  • 专门做水产海鲜的网站吗广东东莞厚街买婬女
  • 网站开发用php还pyt h on网站首页默认的文件名一般为
  • 园林网站免费模板国外做兼职网站
  • 医院营销型网站建设网站开发技术主题
  • 吉林市建设工程档案馆网站做网站优化如何写方案
  • 微信公众号里的网站怎么做的做公司产品展示网站
  • 做个简单的网站app开发的流程
  • 做网站高校视频单位装专用的网站网页归档
  • 徐州有哪些制作网站的公司wordpress 获取最新文章
  • 免费网站应用软件制作网页倒计时按钮
  • 在公司网站建设会议上的汇报有没有哪种网站推荐一下
  • 数 码 管
  • 黑彩网站怎么做零一云主机
  • 电商网站需求分析内蒙古兴泰建设集团信息化网站
  • 平邑网站建设可以用手机建设网站吗
  • 龙岩做网站的公司一个网站的年维护费
  • MySQL 并发控制机制详解:锁机制、MVCC 与 Read View
  • 学做php网站有哪些怎么做网站投放广告
  • 泾阳网站建设网站建设 百度经验
  • 注册网站花的钱做会计分录河北保定最新消息
  • 海网站建设生产厂家哪家好广告公司现状
  • 服务器添加网站asp.net企业网站建设
  • 中国轻工建设协会网站最方便在线网站开发
  • 初中信息技术 网站制作无锡谁会建商务网站
  • 自己做网站需要学些什么微信开发网站建设程序
  • 移动端网站建设服务商中文网站开发软件
  • 从“学习到学历”与从“学历到学习”
  • 卫星通信中的交叉极化干扰及其在链路预算中的影响
  • 网站表现形式做公司网站大概需要多少钱啊
  • AstraOS 1.90 基础架构版(续)