百刀打造ChatGPT:nanochat极简LLM全栈实现深度解析
当ChatGPT横空出世,无数开发者在惊叹其强大能力的同时,也被其天文数字般的训练成本所震慑。动辄上千万美元的算力投入,让大模型训练成为了科技巨头的专利。但如果我告诉你,只需100美元,你就能从零开始训练一个属于自己的ChatGPT,你会相信吗?
这不是天方夜谭,而是Andrej Karpathy(特斯拉前AI总监、OpenAI创始团队成员)最新开源项目nanochat带来的革命性突破。这个项目用不到8000行代码,在4小时内完成了从数据准备、分词器训练、模型预训练、指令微调到Web部署的全流程,真正实现了"The best ChatGPT that $100 can buy"(百刀能买到的最好ChatGPT)。
更令人惊叹的是,这不是一个玩具项目。nanochat在CORE评测集上达到了0.22的分数,在多项基准测试中表现不俗,证明了在极限预算下打造可用LLM的可行性。
本文将深入剖析nanochat的技术架构、核心实现和工程智慧,带你一窥现代LLM全栈开发的精髓。无论你是想学习LLM原理的研究者,还是希望构建垂直领域模型的工程师,这篇文章都将为你揭开大模型神秘面纱的重要一角。
一、技术架构全景:极简主义的工程美学
1.1 项目定位:可黑客化的全栈LLM基线
nanochat的核心理念可以用三个关键词概括:minimal(极简)、hackable(可黑客化)、full-stack(全栈)。
不同于Transformers、DeepSpeed等"大而全"的框架,nanochat刻意避免了过度工程化。整个项目结构清晰到令人愉悦:
nanochat/
├── nanochat/ # 核心库(不到2000行)
│ ├── gpt.py # GPT模型实现(320行)
│ ├── engine.py # 高效推理引擎(350行)
│ ├── tokenizer.py # 双实现分词器(400行)
│ ├── dataloader.py # 流式数据加载(50行)
│ ├── muon.py # Muon优化器(190行)
│ └── ...
├── scripts/ # 训练/评估脚本
├── tasks/ # 评测任务实现
├── rustbpe/ # Rust高性能分词器
└── speedrun.sh # 一键训练脚本
这种极简设计带来了巨大的认知优势:
-
可读性:一个周末就能通读全部核心代码
-
可调试:没有多层抽象的黑盒,每一行都清晰可见
-
可定制:想改什么就改什么,不用担心牵一发动全身
-
可学习:每个决策都有明确的工程考量,是绝佳的教学材料
1.2 四阶段训练流水线
nanochat采用了经典的四阶段训练范式,这也是现代LLM的标准做法:
┌─────────────┐ ┌──────────────┐ ┌─────────────┐ ┌──────────┐
│ 分词器训练 │ --> │ 基座预训练 │ --> │ 中期微调 │ --> │ SFT微调 │
│ Tokenizer │ │ Base Model │ │ Mid-training│ │ Chat │
└─────────────┘ └──────────────┘ └─────────────┘ └──────────┘2B字符 11B tokens 对话格式 指令对齐65K词表 561M参数 特殊token 任务混合
阶段1:分词器训练(Tokenizer Training)
-
在20亿字符的FineWeb-Edu数据上训练BPE分词器
-
词表大小:65,536(2^16),平衡了效率与表达能力
-
双实现:Rust训练(高性能) + tiktoken推理(高效)
-
平均压缩率:4.8字符/token
阶段2:基座预训练(Base Pretraining)
-
模型规模:d20深度(561M参数)
-
训练数据:112亿tokens(遵循Chinchilla定律20:1)
-
训练时长:~2.5小时(8xH100)
-
目标:学习语言基础知识、常识推理
阶段3:中期微调(Mid-training)
-
引入对话格式的特殊tokens:
<|user_start|>
,<|assistant_start|>
等 -
教会模型工具使用(calculator tool)
-
适应多轮对话结构
-
训练时长:~30分钟
阶段4:监督微调(SFT)
-
任务混合:ARC、GSM8K、HumanEval、SmolTalk
-
领域对齐:让模型学会"如何表现"
-
训练时长:~20分钟
-
可选:强化学习(RL)进一步提升数学推理能力
整个流程设计巧妙地平衡了"能力获取"与"行为塑造",每个阶段都有明确的目标和可量化的评估指标。
1.3 依赖管理:拥抱现代工具链
nanochat在依赖管理上采用了2025年的最佳实践:
# pyproject.toml
[project]
dependencies = ["torch>=2.8.0", # PyTorch 2.x的编译优化"tokenizers>=0.22.0","tiktoken>=0.11.0", # OpenAI的高效tokenizer"datasets>=4.0.0", # HuggingFace数据集"fastapi>=0.117.1", # Web服务...
][build-system]
requires = ["maturin>=1.7"] # Rust-Python互操作
build-backend = "maturin"
特别值得注意的几个设计:
-
uv包管理器:取代pip,速度提升10-100倍,依赖解析更智能
-
Rust融合:用Maturin无缝集成Rust模块,性能关键部分用Rust重写
-
CUDA 12.8:明确指定PyTorch的CUDA版本,避免兼容性问题
-
最小化依赖:仅2004行依赖(uv.lock),远少于典型项目
1.4 核心技术选型理念
nanochat的每一个技术选择都经过深思熟虑:
技术点 | 选择 | 理由 |
---|---|---|
模型架构 | GPT-style Transformer | 简单、稳定、易于理解 |
注意力机制 | MQA(Multi-Query Attention) | 推理速度快,显存占用低 |
激活函数 | ReLU² | 训练稳定,计算高效 |
位置编码 | RoPE(Rotary Position Embedding) | 外推性好,无需学习参数 |
归一化 | RMSNorm(无可学习参数) | 训练稳定,减少参数量 |
优化器 | Muon(矩阵) + AdamW(嵌入层) | Muon收敛更快,AdamW稳定性好 |
分词器 | GPT-4风格BPE | 压缩率高,通用性强 |
数据集 | FineWeb-Edu | 高质量教育内容,公开可得 |
这些选择背后的逻辑是:优先选择简单、稳定、已验证的技术,而非追求最新、最复杂的方案。这正是nanochat能在极短代码量内实现完整功能的关键。
二、GPT模型实现:现代Transformer的极简重构
2.1 模型配置:深度优先的参数分配
nanochat采用了一个非常有趣的参数配置策略:
@dataclass
class GPTConfig:sequence_len: int = 1024vocab_size: int = 50304n_layer: int = 12 # 深度n_head: int = 6 # 查询头数量n_kv_head: int = 6 # KV头数量(MQA)n_embd: int = 768 # 嵌入维度
关键设计决策:
1. 深度与宽度的trade-off
nanochat使用公式 model_dim = depth * 64
来计算嵌入维度。对于d20模型:
-
深度(n_layer)= 20
-
宽度(n_embd)= 20 × 64 = 1280
-
头维度(head_dim)= 1280 / 10 = 128
这种"深度优先"策略基于研究发现:在相同参数量下,更深的网络往往表现更好。aspect ratio(宽度/深度)保持在64左右是一个经验值。
2. Multi-Query Attention(MQA)
# 传统Multi-Head Attention
n_head = 10 # 10个Query头
n_kv_head = 10 # 10个KV头# MQA配置
n_head = 10 # 10个Query头
n_kv_head = 10 # 10个KV头(可以设为1以节省显存)
MQA的核心思想是让多个Query头共享同一组Key和Value,在推理时可以显著减少KV Cache的显存占用,几乎不损失性能。
3. 词表大小填充
注意到 vocab_size = 50304
而不是整数?这是因为:
-
实际训练的词表可能是65536(2^16)
-
但代码中预留了可调整空间,填充到64的倍数以优化GPU计算
2.2 核心组件实现
2.2.1 RMSNorm:极简的归一化
def norm(x):# 纯函数式RMSNorm,无可学习参数return F.rms_norm(x, (x.size(-1),))
这可能是你见过最简洁的归一化实现。传统LayerNorm有两个可学习参数(scale和bias),而RMSNorm:
-
只进行均方根归一化,不减均值
-
完全无参数,减少4-8M参数量
-
训练稳定性不亚于LayerNorm
2.2.2 RoPE:相对位置编码
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000):# 计算旋转频率channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32)inv_freq = 1.0 / (base ** (channel_range / head_dim))t = torch.arange(seq_len, dtype=torch.float32)freqs = torch.outer(t, inv_freq)cos, sin = freqs.cos(), freqs.sin()return cos, sindef apply_rotary_emb(x, cos, sin):d = x.shape[3] // 2x1, x2 = x[..., :d], x[..., d:]y1 = x1 * cos + x2 * siny2 = x1 * (-sin) + x2 * cosreturn torch.cat([y1, y2], 3)
RoPE的精妙之处:
-
相对位置感知:通过旋转矩阵编码位置,自然支持相对位置建模
-
外推能力强:训练在2048长度,推理时可以扩展到更长
-
无额外参数:位置信息通过数学变换注入,不占用参数空间
实现细节:
-
预计算cos/sin矩阵,避免重复计算
-
存储在bfloat16,节省显存
-
缓存10倍序列长度,支持长文本推理
2.2.3 注意力机制:Flash Attention集成
def forward(self, x, cos_sin, kv_cache):# 投影到Q、K、Vq = self.c_q(x).view(B, T, self.n_head, self.head_dim)k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)# RoPE + QK归一化q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)q, k = norm(q), norm(k) # QK norm提升训练稳定性# MQA:复制KV头以匹配Q头数量k, v = repeat_kv(k, self.n_head // self.n_kv_head), repeat_kv(v, ...)# Flash Attention(自动选择最优实现)y = F.scaled_dot_product_attention(q, k, v, is_causal=True)return self.c_proj(y)
这段代码有几个值得玩味的细节:
1. QK Normalization
q, k = norm(q), norm(k)
在Q和K上再次应用归一化,这是Gemma等新模型的做法,能提升训练稳定性,防止注意力分数爆炸。
2. 自动优化的Flash Attention
F.scaled_dot_product_attention(q, k, v, is_causal=True)
PyTorch 2.x的这个函数会自动选择:
-
Flash Attention 2(最优实现)
-
Memory-efficient attention(显存受限时)
-
标准实现(兜底方案)
无需手动管理,性能提升2-4倍!
3. KV Cache处理
if kv_cache is not None:k, v = kv_cache.insert_kv(self.layer_idx, k, v)
推理时使用KV Cache是标配优化,避免重复计算历史token的K和V。nanochat的实现支持:
-
自动扩容(动态增长)
-
批量prefill(一次性计算prompt)
-
渐进式解码(逐token生成)
2.2.4 MLP:ReLU²激活
class MLP(nn.Module):def __init__(self, config):super().__init__()self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)def forward(self, x):x = self.c_fc(x)x = F.relu(x).square() # ReLU²x = self.c_proj(x)return x
为什么用ReLU²而不是GELU/SwiGLU?
-
计算效率:ReLU²比GELU快约30%
-
训练稳定:不像GELU在训练初期可能不稳定
-
性能相当:在小模型上,性能差异<1%
这是典型的"简单就是美"——在不损失性能的前提下,选择最简单的实现。
2.3 权重初始化:Spectral Initialization
def _init_weights(self, module):if isinstance(module, nn.Linear):fan_out = module.weight.size(0)fan_in = module.weight.size(1)std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))torch.nn.init.normal_(module.weight, mean=0.0, std=std)
这个初始化策略来自论文"Spectral Initialization",核心思想:
-
基础方差:
1/√fan_in
(Xavier初始化) -
修正因子:
min(1.0, √(fan_out / fan_in))
-
当输出维度小于输入维度时,减小初始化方差
特殊处理:
# 投影层初始化为0(残差连接优化)
torch.nn.init.zeros_(block.mlp.c_proj.weight)
torch.nn.init.zeros_(block.attn.c_proj.weight)
# 输出层初始化为0
torch.nn.init.zeros_(self.lm_head.weight)
这样做的好处:
-
训练初期残差路径主导,主路径逐渐学习
-
类似于"warm-up"的效果,但在权重层面实现
-
提升训练稳定性,减少早期loss震荡
三、Muon优化器:下一代训练加速器
3.1 为什么需要新的优化器?
在深度学习的历史长河中,优化器经历了多次革命:SGD → Momentum → Adam → AdamW。每次革新都带来了训练速度或效果的提升。但到了Transformer时代,我们发现Adam系列在训练大型语言模型时存在一些问题:
-
内存占用大:需要存储一阶和二阶动量,参数量翻倍
-
超参数敏感:lr、β1、β2、ε需要仔细调优
-
计算开销高:每步都要计算动量的指数滑动平均
Muon(Momentum Orthogonalized by Newton-schulz)优化器的出现,正是为了解决这些问题。
3.2 Muon的核心思想
Muon的设计哲学可以用一句话概括:在SGD-Momentum的基础上,通过正交化投影实现更快的收敛。
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5) -> Tensor:"""使用Newton-Schulz迭代计算矩阵的零次幂(正交化)输入:梯度矩阵 G输出:最接近的正交矩阵 ~UV^T(其中 USV^T = G 是SVD分解)"""a, b, c = (3.4445, -4.7750, 2.0315) # 五次迭代的优化系数X = G.bfloat16()# 如果行数>列数,转置以提高效率if G.size(-2) > G.size(-1):X = X.mT# 归一化谱范数到1X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)# Newton-Schulz迭代for _ in range(steps):A = X @ X.mTB = b * A + c * A @ A # 五次迭代X = a * X + B @ Xif G.size(-2) > G.size(-1):X = X.mTreturn X
这段代码看起来晦涩,但其实在做一件事:找到与梯度最接近的正交矩阵。
为什么要正交化?
-
避免梯度方向塌陷:正交矩阵保证更新方向在各个维度上均衡
-
加速收敛:正交更新等价于在参数空间做"最短路径"
-
数值稳定:正交矩阵的条件数为1,避免梯度爆炸/消失
Newton-Schulz迭代的魔法
传统计算矩阵正交化需要SVD分解,复杂度O(n³)且不稳定。Newton-Schulz方法:
-
复杂度:O(n²) × 5次迭代
-
数值稳定:在bfloat16下都能工作
-
可编译:用
@torch.compile
加速,接近手写CUDA性能
3.3 Muon优化器的完整实现
class Muon(torch.optim.Optimizer):def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)# 按参数大小分组(重要优化!)params = list(params)param_groups = []for size in {p.numel() for p in params}:group = dict(params=[p for p in params if p.numel() == size])param_groups.append(group)super().__init__(param_groups, defaults)@torch.no_grad()def step(self):for group in self.param_groups:for p in group["params"]:g = p.gradstate = self.state[p]# 初始化momentum bufferif "momentum_buffer" not in state:state["momentum_buffer"] = torch.zeros_like(g)# 标准Momentum更新buf = state["momentum_buffer"]buf.lerp_(g, 1 - group["momentum"])g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf# Muon的核心:正交化g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])# 应用更新(带aspect ratio缩放)scale = max(1, p.size(-2) / p.size(-1)) ** 0.5p.add_(g, alpha=-group["lr"] * scale)
几个关键设计:
1. 按大小分组(Batched Optimization)
for size in {p.numel() for p in params}:group = dict(params=[p for p in params if p.numel() == size])
相同大小的参数打包处理,可以:
-
利用批量矩阵运算(BLAS Level 3)
-
减少kernel启动开销
-
提高GPU利用率
2. Aspect Ratio缩放
scale = max(1, p.size(-2) / p.size(-1)) ** 0.5
这是Muon的一个subtle但重要的技巧:
-
矩阵越"瘦"(行多列少),学习率越大
-
补偿正交化在不同形状矩阵上的不均衡效应
-
实验表明能提升5-10%收敛速度
3. Nesterov动量
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
Nesterov动量提供"预见"效果:
-
先按momentum方向前进一步
-
在前进后的位置计算梯度
-
更准确地估计最优方向
3.4 分布式Muon:DistMuon
在多GPU训练中,Muon需要特殊处理以保证正确性和效率:
class DistMuon(torch.optim.Optimizer):@torch.no_grad()def step(self):rank = dist.get_rank()world_size = dist.get_world_size()# 1. Reduce-scatter:梯度求平均all_reduce_futures = []for group in self.param_groups:params = group["params"]for base_i in range(0, len(params), world_size):owner_idx = base_i + rank # 每个rank负责一部分参数rs_input = [p.grad for p in params[base_i:base_i + world_size]]rs_output = params[owner_idx].grad if owner_idx < len(params) else ...work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True)all_reduce_futures.append(work)# 2. 各rank独立更新自己负责的参数for future, param in zip(all_reduce_futures, owner_params):future.wait()# ... Muon更新逻辑 ...# 3. All-gather:同步更新后的参数all_gather_futures = []for base_i in range(0, len(params), world_size):ag_input = params[owner_idx] if owner_idx < len(params) else ...ag_output = params[base_i:base_i + world_size]work = dist.all_gather(ag_output, ag_input, async_op=True)all_gather_futures.append(work)# 等待所有通信完成torch.futures.collect_all(all_gather_futures).wait()
这个实现的精妙之处:
-
Block-cyclic分配:参数按world_size分块,每个rank负责一块,负载均衡
-
异步通信:reduce-scatter和all-gather异步进行,与计算overlap
-
内存高效:每个rank只存储部分momentum buffer,节省显存
性能对比:
-
相比Adam:收敛速度快20-30%
-
相比SGD:最终精度高2-3%
-
显存占用:与SGD相当(仅存一阶动量)
3.5 混合优化器策略
nanochat采用了一个聪明的策略:不同参数用不同优化器。
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02):# 矩阵参数(Attention + MLP)用Muonmatrix_params = list(self.transformer.h.parameters())muon_optimizer = Muon(matrix_params, lr=matrix_lr, momentum=0.95)# 嵌入层和输出层用AdamWembedding_params = list(self.transformer.wte.parameters())lm_head_params = list(self.lm_head.parameters())adam_groups = [dict(params=lm_head_params, lr=unembedding_lr),dict(params=embedding_params, lr=embedding_lr),]adamw_optimizer = AdamW(adam_groups, betas=(0.8, 0.95), eps=1e-10)return [adamw_optimizer, muon_optimizer]
为什么这样划分?
参数类型 | 优化器 | 学习率 | 理由 |
---|---|---|---|
Transformer矩阵 | Muon | 0.02 | 正交化加速收敛,适合密集矩阵 |
Token嵌入 | AdamW | 0.2 | 稀疏更新,Adam自适应学习率更稳定 |
输出层 | AdamW | 0.004 | 直接影响loss,需要保守更新 |
学习率比例:
-
嵌入层 : 矩阵层 : 输出层 = 50 : 5 : 1
-
嵌入层最高:因为每次只更新少量token的嵌入
-
输出层最低:避免训练后期loss震荡
dmodel缩放
dmodel_lr_scale = (model_dim / 768) ** -0.5
for group in adam_groups:group["lr"] *= dmodel_lr_scale
这个缩放因子来自μP(Maximal Update Parametrization)理论:
-
模型越宽,学习率应越小
-
缩放因子 ∝ 1/√d,保证不同宽度模型的"有效"学习率一致
-
便于从小模型的超参数迁移到大模型
四、高性能分词器:Rust + Python的完美融合
4.1 为什么分词器如此重要?
分词器是LLM的"第一道门",其设计直接影响:
-
压缩率:字符→token的转换效率,影响上下文长度和推理速度
-
泛化性:词表覆盖能力,决定了模型对未见过词汇的处理
-
性能:训练时每秒要处理数百万字符,分词速度至关重要
nanochat采用GPT-4风格的BPE(Byte Pair Encoding),但实现上做了两个大胆的选择:
-
训练用Rust:利用Rust的零成本抽象和并行计算能力
-
推理用tiktoken:OpenAI开源的高效C++实现,通过Python绑定使用
这种"两条腿走路"的策略充分发挥了各自优势。
4.2 GPT-4风格的文本切分
在应用BPE之前,需要先将文本切分成"块"(chunks)。GPT-4使用了一个精心设计的正则表达式:
SPLIT_PATTERN = r"""
'(?i:[sdmt]|ll|ve|re)| # 缩写:'s, 'm, 't, 'll, 've, 're
[^\r\n\p{L}\p{N}]?+\p{L}+| # 单词(可选前导非字母)
\p{N}{1,2}| # 数字(1-2位一组)?[^\s\p{L}\p{N}]++[\r\n]*| # 标点符号
\s*[\r\n]| # 换行
\s+(?!\S)| # 空格(后面不跟非空白)
\s+ # 其他空格
"""
设计考量:
-
缩写特殊处理:确保"don't"不会被拆成"don"+"'"+"t"
- 数字分组:1-2位一组(nanochat改动),而不是GPT-4的1-3位
-
理由:小词表场景下,节省token空间
-
缺点:大数字需要更多token表示
-
-
Unicode分类:使用
\p{L}
(字母)、\p{N}
(数字)支持多语言
4.3 Rust实现的高性能BPE训练
BPE算法的核心是贪心合并:
-
统计所有相邻token对的频率
-
找到频率最高的pair
-
合并这个pair成新token
-
重复直到达到目标词表大小
看似简单,但在百亿字符的数据上,计算量惊人。nanochat的Rust实现有几个巧妙优化:
4.3.1 数据结构设计
struct Word {ids: Vec<u32>, // token ID序列
}struct MergeJob {pair: (u32, u32), // 要合并的paircount: u64, // 频率pos: AHashSet<usize>, // 出现位置集合
}
关键点:
-
Word
只存储ID,不存原始字符串(节省内存) -
MergeJob
记录位置信息,避免全局扫描 -
使用
AHashSet
(ahash)而不是std::HashSet
,速度快30%
4.3.2 并行化策略
fn count_pairs_parallel(words: &[Word],counts: &[i32],
) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {words.par_iter() // Rayon并行迭代.enumerate().map(|(i, w)| {// 每个线程独立统计let mut local_pc: AHashMap<Pair, i32> = AHashMap::new();let mut local_wtu: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();for (a, b) in w.pairs() {*local_pc.entry((a, b)).or_default() += counts[i];local_wtu.entry((a, b)).or_default().insert(i);}(local_pc, local_wtu)}).reduce(|| (AHashMap::new(), AHashMap::new()),|(mut acc_pc, mut acc_wtu), (pc, wtu)| {// 合并局部结果for (k, v) in pc { *acc_pc.entry(k).or_default() += v; }for (k, s) in wtu { acc_wtu.entry(k).or_default().extend(s); }(acc_pc, acc_wtu)},)
}
这是经典的map-reduce模式:
-
Map阶段:每个线程处理一部分words,统计local pair counts
-
Reduce阶段:合并所有线程的结果
性能提升:
-
单线程:~30分钟
-
8线程:~5分钟(5-6倍加速)
4.3.3 增量更新优化
传统BPE每次合并都重新统计全局pair counts,复杂度O(N²)。nanochat使用增量更新:
fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {// 只记录局部变化let mut deltas: Vec<(Pair, i32)> = Vec::new();while i < n {if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {let left = out.last().copied();let right = if i + 2 < n { Some(self.ids[i + 2]) } else { None };// 受影响的pair:左邻、自己、右邻if let Some(x) = left {deltas.push(((x, a), -1)); // 移除deltas.push(((x, new_id), 1)); // 新增}deltas.push(((a, b), -1)); // 移除if let Some(y) = right {deltas.push(((b, y), -1));deltas.push(((new_id, y), 1));}out.push(new_id);i += 2;} else {out.push(self.ids[i]);i += 1;}}return deltas;
}
每次合并只产生O(1)个delta,全局更新变成:
for (pair, delta) in changes {*pair_counts.entry(pair).or_default() += delta * counts[word_idx];
}
复杂度从O(N²)降到O(N)!
4.3.4 堆优化
使用OctonaryHeap
(8叉堆)而不是二叉堆:
-
每次pop需要O(log₈ N) = 1/3 × O(log₂ N)次比较
-
虽然每层比较次数增加,但层数大幅减少
-
CPU缓存友好性更好
let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
for (pair, pos) in where_to_update.drain() {heap.push(MergeJob { pair, count: ... });
}while merges_done < num_merges {let Some(mut top) = heap.pop() else { break; };// Lazy update:延迟刷新countlet current = *pair_counts.get(&top.pair).unwrap_or(&0);if top.count != current as u64 {top.count = current as u64;heap.push(top); // 重新入堆continue;}// ... 执行合并 ...
}
4.4 tiktoken推理
训练完成后,nanochat切换到tiktoken进行推理:
class RustBPETokenizer:def __init__(self, enc, bos_token):self.enc = enc # tiktoken.Encoding对象def encode(self, text, prepend=None, num_threads=8):if isinstance(text, str):ids = self.enc.encode_ordinary(text)elif isinstance(text, list):# 批量编码,自动并行ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)# ... 处理prepend/append ...return ids
tiktoken的优势:
-
C++实现:核心算法用C++编写,比纯Python快10-100倍
-
批量优化:自动并行处理batch,充分利用多核
-
缓存友好:使用hash trie存储merges,查找O(1)
性能对比(100K文档):
-
HuggingFace tokenizer:~45秒
-
tiktoken:~3秒(15倍加速)
4.5 对话格式渲染
对于SFT阶段,需要将对话转换为带特殊token的序列:
def render_conversation(self, conversation, max_tokens=2048):ids, mask = [], []# 特殊tokenbos = self.get_bos_token_id()user_start, user_end = ...assistant_start, assistant_end = ...python_start, python_end = ... # 工具使用# 渲染对话add_tokens(bos, mask=0)for message in conversation["messages"]:if message["role"] == "user":add_tokens(user_start, 0)add_tokens(self.encode(message["content"]), 0)add_tokens(user_end, 0)elif message["role"] == "assistant":add_tokens(assistant_start, 0)# 只有assistant的内容被mask=1(训练目标)add_tokens(self.encode(message["content"]), 1)add_tokens(assistant_end, 1)return ids[:max_tokens], mask[:max_tokens]
Mask机制:
-
mask=0
:不计算loss(prompt部分) -
mask=1
:计算loss(要学习的部分)
这样模型只学习生成assistant的回复,而不是重复用户的问题。
工具使用格式:
<|user_start|>计算123 + 456<|user_end|>
<|assistant_start|><|python_start|>123 + 456<|python_end|>
<|output_start|>579<|output_end|>
答案是579<|assistant_end|>
-
<|python_start|>...<|python_end|>
:模型生成的Python表达式(mask=1) -
<|output_start|>...<|output_end|>
:执行结果(mask=0,因为来自外部工具)
五、高效推理引擎:从理论到实践
5.1 KV Cache:推理加速的基石
在自回归生成中,每生成一个新token都要重新计算整个序列的注意力。假设序列长度为T:
-
第1个token:计算1个位置的attention
-
第2个token:计算2个位置的attention
-
第T个token:计算T个位置的attention
-
总计算量:O(T²)
这是巨大的浪费!因为前T-1个位置的Key和Value其实不会变。KV Cache的思想就是:缓存已计算的K和V,每次只计算新token的K和V。
5.1.1 KV Cache实现
class KVCache:def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):# 每层存储K和V:(num_layers, 2, B, H, T, D)self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)self.kv_cache = Noneself.pos = 0 # 当前填充到的位置def insert_kv(self, layer_idx, k, v):# 延迟初始化(知道dtype和device)if self.kv_cache is None:self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)B, H, T_add, D = k.size()t0, t1 = self.pos, self.pos + T_add# 动态扩容if t1 > self.kv_cache.size(4):t_needed = (t1 + 1024 + 1023) & ~1023 # 向上取整到1024的倍数current_shape = list(self.kv_cache.shape)current_shape[4] = t_neededself.kv_cache.resize_(current_shape)# 插入新的K和Vself.kv_cache[layer_idx, 0, :, :, t0:t1] = kself.kv_cache[layer_idx, 1, :, :, t0:t1] = v# 返回累积的K和V(view,无拷贝)key_view = self.kv_cache[layer_idx, 0, :, :, :t1]value_view = self.kv_cache[layer_idx, 1, :, :, :t1]# 最后一层更新posif layer_idx == self.kv_cache.size(0) - 1:self.pos = t1return key_view, value_view
设计亮点:
1. 延迟初始化(Lazy Initialization)
if self.kv_cache is None:self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
好处:
-
构造KVCache时不需要知道dtype和device
-
避免在meta device上创建tensor(用于模型初始化)
-
支持动态切换精度(fp32/bf16)
2. 动态扩容(Dynamic Resizing)
if t1 > self.kv_cache.size(4):t_needed = (t1 + 1024 + 1023) & ~1023 # 位运算向上取整
这段代码做了两件事:
-
增长1024的buffer(避免频繁扩容)
-
向上对齐到1024的倍数(GPU内存对齐优化)
例如:需要2050个位置 → 扩容到3072(2050+1024=3074 → 向上取整到3072)
3. Zero-copy视图
key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
使用PyTorch的view机制,返回的是原tensor的slice,不会拷贝数据。这在长序列生成时节省大量时间。
5.1.2 Prefill优化
在batch生成时,常见场景是:
-
先用batch=1 prefill prompt(预填充)
-
然后复制KV cache到batch=N
-
并行生成N个样本
def prefill(self, other):"""从另一个KVCache预填充"""assert self.kv_cache is None, "只能预填充空cache"assert other.kv_cache is not None# 验证维度兼容性for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):if ix == 2: # batch维度可以扩展assert dim1 == dim2 or dim2 == 1elif ix == 4: # seq_len必须足够长assert dim1 >= dim2else: # 其他维度必须匹配assert dim1 == dim2# 初始化并拷贝dtype, device = other.kv_cache.dtype, other.kv_cache.deviceself.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cacheself.pos = other.pos
这样设计的好处:
-
Prompt只计算一次(节省计算)
-
支持batch>1的parallel sampling(提高吞吐)
-
代码复用性好(prefill和decode用同一套逻辑)
5.2 流式生成(Streaming Generation)
用户体验的关键在于"逐token显示"而非"等待全部完成"。nanochat的流式生成设计优雅:
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):# 1. Prefill:batch=1处理promptkv_cache_prefill = KVCache(batch_size=1, seq_len=len(tokens), ...)ids = torch.tensor([tokens], device=device)logits = self.model.forward(ids, kv_cache=kv_cache_prefill)next_ids = sample_next_token(logits[:, -1, :], rng, temperature, top_k)# 2. 复制KV cache到batch=num_sampleskv_cache_decode = KVCache(batch_size=num_samples, seq_len=..., ...)kv_cache_decode.prefill(kv_cache_prefill)# 3. 流式生成row_states = [RowState(tokens.copy()) for _ in range(num_samples)]while True:# ... 采样逻辑 ...# Yield一列tokens(每行一个)yield token_column, token_masks# 准备下一轮输入ids = torch.tensor(token_column, device=device).unsqueeze(1)
Generator设计:
-
使用Python生成器(
yield
),调用方可以逐token消费 -
返回
(token_column, token_masks)
,支持batch生成 -
token_masks
标记哪些token是采样的(1)哪些是强制的(0,用于工具调用)
调用示例:
engine = Engine(model, tokenizer)
for token_column, token_masks in engine.generate(prompt_tokens, num_samples=3, max_tokens=100):for i, token in enumerate(token_column):print(f"Sample {i}: {tokenizer.decode([token])}", end="", flush=True)
5.3 工具调用:Calculator Tool
nanochat实现了一个简单但实用的工具:计算器。模型可以主动调用计算器进行精确计算。
5.3.1 状态机设计
class RowState:def __init__(self, current_tokens=None):self.current_tokens = current_tokens or []self.forced_tokens = deque() # 强制插入的tokensself.in_python_block = False # 是否在python块内self.python_expr_tokens = [] # python表达式的tokensself.completed = False # 是否完成生成
每个生成样本维护一个状态机,跟踪:
-
当前生成到哪里
-
是否进入了
<|python_start|>
块 -
收集到的python表达式
-
待插入的工具输出tokens
5.3.2 工具调用流程
# 获取特殊tokens
python_start = tokenizer.encode_special("<|python_start|>")
python_end = tokenizer.encode_special("<|python_end|>")
output_start = tokenizer.encode_special("<|output_start|>")
output_end = tokenizer.encode_special("<|output_end|>")for token_column, token_masks in ...:for i, state in enumerate(row_states):next_token = token_column[i]state.current_tokens.append(next_token)if next_token == python_start:# 进入python块state.in_python_block = Truestate.python_expr_tokens = []elif next_token == python_end and state.in_python_block:# 退出python块,执行计算state.in_python_block = Falseexpr = tokenizer.decode(state.python_expr_tokens)result = use_calculator(expr) # 调用计算器if result is not None:# 将结果tokens强制插入生成序列result_tokens = tokenizer.encode(str(result))state.forced_tokens.append(output_start)state.forced_tokens.extend(result_tokens)state.forced_tokens.append(output_end)elif state.in_python_block:# 收集python表达式state.python_expr_tokens.append(next_token)
执行示例:
用户输入:
计算 123 * 456
模型生成:
<|assistant_start|>让我计算一下<|python_start|>123 * 456<|python_end|>
此时触发计算器:
expr = "123 * 456"
result = eval(expr) # 56088
强制插入:
<|output_start|>56088<|output_end|>
模型继续:
结果是56088<|assistant_end|>
5.3.3 安全执行
def use_calculator(expr):# 白名单检查if any([x not in "0123456789*+-/.() " for x in expr]):return None # 拒绝非数学字符if "**" in expr:return None # 拒绝幂运算(防止过大计算)# 超时保护return eval_with_timeout(expr, max_time=3)@contextmanager
def timeout(duration, formula):def timeout_handler(signum, frame):raise Exception(f"timed out after {duration} seconds")signal.signal(signal.SIGALRM, timeout_handler)signal.alarm(duration)yieldsignal.alarm(0)
安全措施:
-
白名单过滤:只允许数字和基本运算符
-
禁止危险操作:如
**
(幂运算)可能导致计算爆炸 -
超时保护:3秒内必须完成,否则终止
-
异常捕获:任何错误都返回None,不影响生成
5.4 采样策略
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):# Temperature = 0:贪心解码if temperature == 0.0:return torch.argmax(logits, dim=-1, keepdim=True)# Top-k采样if top_k is not None:k = min(top_k, logits.size(-1))vals, idx = torch.topk(logits, k, dim=-1)vals = vals / temperatureprobs = F.softmax(vals, dim=-1)choice = torch.multinomial(probs, num_samples=1, generator=rng)return idx.gather(1, choice)# 标准采样else:logits = logits / temperatureprobs = F.softmax(logits, dim=-1)return torch.multinomial(probs, num_samples=1, generator=rng)
Temperature的作用:
Temperature | 效果 | 适用场景 |
---|---|---|
0.0 | 贪心(argmax) | 需要确定性输出(代码生成、数学计算) |
0.5-0.7 | 低随机性 | 事实问答、摘要 |
0.8-1.0 | 平衡 | 通用对话 |
1.2-1.5 | 高随机性 | 创意写作、头脑风暴 |
Top-k的作用:
-
只从概率最高的k个token中采样
-
避免采样到低概率的"离谱"token
-
通常设置为50左右
组合策略:
# 平衡创意与质量
engine.generate(tokens, temperature=0.9, top_k=50)
六、训练流程:从零到可用的完整Pipeline
6.1 数据准备:FineWeb-Edu
nanochat使用FineWeb-Edu作为预训练数据集,这是HuggingFace精心筛选的高质量教育内容。
数据规模:
-
总大小:100B tokens(约480GB文本)
-
Shard数量:1822个parquet文件
-
每个shard:~250M字符(~100MB压缩)
下载策略:
def download_single_file(index):filename = f"shard_{index:05d}.parquet"url = f"{BASE_URL}/{filename}"# 增量下载(带重试)for attempt in range(1, 6):try:response = requests.get(url, stream=True, timeout=30)with open(temp_path, 'wb') as f:for chunk in response.iter_content(chunk_size=1MB):f.write(chunk)os.rename(temp_path, filepath)return Trueexcept Exception as e:wait_time = 2 ** attempttime.sleep(wait_time) # 指数退避
并行下载:
with Pool(processes=4) as pool:results = pool.map(download_single_file, ids_to_download)
4个进程并行,充分利用网络带宽。
数据量计算:
d20模型(561M参数)需要:
tokens_needed = params × 20 (Chinchilla)= 561M × 20= 11.2B tokens≈ 54B characters (假设4.8 char/token)≈ 216 shards (54B / 250M)
实际下载240个shards,留有余量。
6.2 流式DataLoader
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4):needed_tokens = B * T + 1tokenizer = get_tokenizer()bos_token = tokenizer.get_bos_token_id()token_buffer = deque()def document_batches():while True:for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):for i in range(0, len(batch), 128): # 分成小批yield batch[i:i+128]batches = document_batches()while True:# 填充buffer到足够大while len(token_buffer) < needed_tokens:doc_batch = next(batches)token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=4)for tokens in token_lists:token_buffer.extend(tokens)# 从buffer取出B*T+1个tokensfor i in range(needed_tokens):scratch[i] = token_buffer.popleft()# 构造inputs和targetsinputs = scratch[:-1].view(B, T).cuda()targets = scratch[1:].view(B, T).cuda()yield inputs, targets
设计亮点:
1. 流式处理
-
不把整个数据集加载到内存
-
逐个parquet文件读取,处理完即释放
-
支持无限epoch(while True)
2. 分布式友好
for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
-
每个GPU处理不同的parquet文件
-
start=rank, step=world_size
实现数据并行 -
无需额外的分布式sampler
3. 异步分词
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=4)
-
分词在CPU上并行进行
-
GPU忙于前向/反向时,CPU在准备下一批数据
-
Overlap计算和数据准备
4. Pinned Memory
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
-
使用page-locked内存
-
CPU→GPU传输速度提升2-3倍
6.3 训练循环
for step in range(num_iterations):# ===== 评估 =====if step % eval_every == 0:val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)print(f"Validation bpb: {val_bpb:.4f}")if step % core_metric_every == 0:core_score = evaluate_model(model, tokenizer, device, max_per_task=500)print(f"CORE metric: {core_score:.4f}")# ===== 训练 =====for micro_step in range(grad_accum_steps):with autocast_ctx:loss = model(x, y)loss = loss / grad_accum_stepsloss.backward()x, y = next(train_loader) # 预取下一批# 梯度裁剪if grad_clip > 0.0:torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)# 学习率调度lrm = get_lr_multiplier(step)for opt in optimizers:for group in opt.param_groups:group["lr"] = group["initial_lr"] * lrm# 优化器stepfor opt in optimizers:opt.step()model.zero_grad(set_to_none=True)
关键技术:
1. 混合精度训练
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
-
前向/反向用bfloat16(节省显存+加速)
-
梯度累积用float32(保证精度)
-
自动转换,无需手动管理
2. 梯度累积
for micro_step in range(grad_accum_steps):loss = loss / grad_accum_stepsloss.backward()
-
模拟大batch训练
-
grad_accum_steps = total_batch_size // (device_batch_size * world_size)
-
每个micro-batch的loss要除以累积步数
3. 学习率调度
def get_lr_multiplier(it):warmup_iters = round(0.0 * num_iterations)warmdown_iters = round(0.2 * num_iterations)if it < warmup_iters:return (it + 1) / warmup_iters # Linear warmupelif it <= num_iterations - warmdown_iters:return 1.0 # 恒定学习率else:progress = (num_iterations - it) / warmdown_itersreturn progress * 1.0 + (1 - progress) * 0.0 # Linear warmdown
学习率曲线:
LR|
1.0| ┌────────────────────┐| │ └─┐| │ └─┐| │ └─┐
0.0|────┘ └────0% 80% 100% iters
warmdown(余弦衰减的简化版)能提升最终精度1-2%。
4. Momentum调度(仅Muon)
def get_muon_momentum(it):frac = min(it / 300, 1)return (1 - frac) * 0.85 + frac * 0.95
-
前300步:momentum从0.85升到0.95
-
类似"warm-up",让模型先探索再稳定
6.4 评估指标
6.4.1 BPB(Bits Per Byte)
def evaluate_bpb(model, val_loader, eval_steps, token_bytes):losses = []for _ in range(eval_steps):x, y = next(val_loader)with torch.no_grad():loss = model(x, y, loss_reduction='sum')losses.append(loss)total_loss = sum(losses)total_tokens = eval_steps * B * T# 计算BPBnll = total_loss / total_tokenstoken_bpb = nll / math.log(2) # nats → bits# 加权到字节级别bpb = (token_bpb * token_bytes).sum()return bpb
为什么用BPB而不是Perplexity?
-
语言无关:不同语言的perplexity不可比
-
更直观:BPB=1表示平均每字节1比特信息
-
可比较性强:可以和压缩算法(gzip等)对比
典型BPB值:
-
随机猜测:8.0 bpb
-
gzip压缩:2.5-3.5 bpb
-
GPT-2:0.9-1.0 bpb
-
GPT-3:0.7-0.8 bpb
-
nanochat d20:~1.2 bpb
6.4.2 CORE Metric
CORE是一个综合评测,包含1400道多选题,涵盖:
-
常识推理
-
科学知识
-
历史地理
-
数学逻辑
def evaluate_model(model, tokenizer, device, max_per_task=500):# 对每个问题,计算各选项的困惑度def eval_problem(problem):prompt = problem["prompt"]choices = problem["choices"]perplexities = []for choice in choices:full_text = prompt + choicetokens = tokenizer.encode(full_text, prepend="<|bos|>")with torch.no_grad():logits = model(tokens)loss = F.cross_entropy(logits[:-1], tokens[1:])perplexities.append(loss.item())# 困惑度最低的选项=模型预测predicted = np.argmin(perplexities)return predicted == problem["answer"]accuracies = [eval_problem(p) for p in problems[:max_per_task]]return np.mean(accuracies)
Centered Results:
centered_results = (accuracy - 0.25) / 0.75
-
随机猜测:25%准确率 → 0分
-
完美模型:100%准确率 → 1分
-
更公平地反映模型能力
nanochat d20的表现:
-
CORE: 0.22(原始0.42)
-
ARC-Easy: 0.36
-
ARC-Challenge: 0.29
-
MMLU: 0.31
-
HumanEval: 0.07
虽然不如大模型,但考虑到100美元的成本,已经相当impressive!
七、工程优化:榨干每一分算力
7.1 编译优化
model = torch.compile(model, dynamic=False)
PyTorch 2.x的killer feature:编译模型为优化的kernel。
加速来源:
-
Operator fusion:多个小op合并成一个大op
-
内存优化:减少中间tensor的分配
-
自动调优:Triton JIT编译,针对硬件优化
实测效果:
-
训练速度提升15-20%
-
推理速度提升30-40%
-
显存占用略有增加(编译开销)
注意事项:
model = torch.compile(model, dynamic=False)
-
dynamic=False
:假设输入shape固定,优化更激进 -
dynamic=True
:支持可变shape,但优化受限
SFT阶段用dynamic=True
因为每个batch的序列长度不同。
7.2 显存优化
7.2.1 激活检查点(Activation Checkpointing)
在训练超大模型时,激活值(中间层输出)是显存杀手。Activation Checkpointing的思想:
-
前向传播:只保存少数关键激活值
-
反向传播:重新计算被丢弃的激活值
PyTorch实现:
from torch.utils.checkpoint import checkpointclass Block(nn.Module):def forward(self, x, cos_sin, kv_cache):# 使用checkpoint包装x = x + checkpoint(self.attn, norm(x), cos_sin, kv_cache, use_reentrant=False)x = x + checkpoint(self.mlp, norm(x), use_reentrant=False)return x
Trade-off:
-
显存占用:减少40-50%
-
训练速度:降低10-15%(多一次前向)
nanochat默认不使用,因为d20模型还装得下。但对于d26(1.1B参数),建议启用。
7.2.2 梯度累积
前面提到过,再强调一次重要性:
# 小显存:device_batch_size=4, grad_accum_steps=64
# 大显存:device_batch_size=32, grad_accum_steps=8
# 效果完全一致!
这是在有限硬件上训练大模型的关键技术。
7.2.3 Expandable Segments
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
PyTorch的显存分配器优化:
-
传统方式:分配固定大小的block,碎片化严重
-
Expandable模式:动态扩展block,减少碎片
实测:OOM边界从28GB提升到30GB(同样模型)。
7.3 分布式训练
# 启动8卡训练
torchrun --standalone --nproc_per_node=8 -m scripts.base_train
7.3.1 DDP(Distributed Data Parallel)
def compute_init():ddp = "RANK" in os.environif ddp:dist.init_process_group(backend="nccl")ddp_rank = dist.get_rank()ddp_local_rank = int(os.environ["LOCAL_RANK"])ddp_world_size = dist.get_world_size()device = f"cuda:{ddp_local_rank}"else:ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1device = "cuda"torch.cuda.set_device(device)return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
工作流程:
-
每个GPU独立前向+反向
-
反向结束后,all-reduce梯度
-
每个GPU用平均后的梯度更新参数
通信优化:
-
使用NCCL(NVIDIA Collective Communications Library)
-
支持gradient bucketing(分批通信)
-
与反向传播overlap(边计算边通信)
效率:
-
理论加速比:N(N卡)
-
实际加速比:0.9N(通信开销~10%)
-
nanochat实测:8卡加速7.2倍
7.3.2 DistMuon的精妙设计
前面提到过,这里再展开:
Block-cyclic分配:
owner_idx = base_i + rank
-
参数0, 8, 16...归rank 0
-
参数1, 9, 17...归rank 1
-
负载均衡,避免某个rank负担过重
三阶段通信:
-
Reduce-scatter:梯度求平均,每个rank得到一部分
-
Local update:各rank独立更新自己的参数
-
All-gather:广播更新后的参数
相比all-reduce:
-
通信量相同
-
但可以overlap更多计算
-
显存占用更低(只在owner上存momentum)
7.4 MFU(Model FLOPs Utilization)
MFU是衡量训练效率的金标准:
promised_flops = 989e12 * ddp_world_size # H100 SXM的理论FLOPs
actual_flops = num_flops_per_token * total_batch_size / dt
mfu = actual_flops / promised_flops
FLOPs估算:
def estimate_flops(self):nparams = sum(p.numel() for p in self.parameters())nparams_embedding = self.transformer.wte.weight.numel()l, h, q, t = self.config.n_layer, self.config.n_head, ...# 6N:前向+反向的矩阵运算# 12lhqt:注意力的额外计算num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * treturn num_flops_per_token
nanochat d20的MFU:
-
单卡A100:~35% MFU
-
单卡H100:~40% MFU
-
8卡H100:~38% MFU
对比:
-
GPT-3训练:~20% MFU(2020年)
-
PaLM训练:~46% MFU(2022年)
-
LLaMA训练:~55% MFU(2023年)
nanochat虽不及SOTA,但对于<10K行代码的项目,已经很优秀。提升空间:
-
Flash Attention 3(预计+5%)
-
自定义fused kernels(+10%)
-
更激进的operator fusion(+5%)
八、Web服务:从训练到生产的最后一公里
8.1 FastAPI架构
nanochat使用FastAPI构建了一个ChatGPT风格的Web服务,代码极其简洁:
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel@asynccontextmanager
async def lifespan(app: FastAPI):"""在启动时加载模型"""print("Loading nanochat model...")app.state.model, app.state.tokenizer, _ = load_model(args.source, device, phase="eval")app.state.engine = Engine(app.state.model, app.state.tokenizer)print(f"Server ready at http://localhost:{args.port}")yieldapp = FastAPI(lifespan=lifespan)@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):"""Chat completion endpoint with streaming"""engine = app.state.enginetokenizer = app.state.tokenizer# 渲染对话为token序列conversation_tokens = render_conversation_to_tokens(request.messages)# 流式生成if request.stream:return StreamingResponse(generate_stream(engine, tokenizer, conversation_tokens, ...),media_type="text/event-stream")else:# 非流式生成result_tokens, _ = engine.generate_batch(conversation_tokens, ...)return {"choices": [{"message": {"role": "assistant", "content": ...}}]}
关键设计:
1. Lifespan管理
@asynccontextmanager
async def lifespan(app: FastAPI):# 启动时:加载模型app.state.model = load_model(...)yield# 关闭时:清理资源(可选)
-
模型只加载一次,所有请求共享
-
避免每次请求都重新加载模型(太慢!)
-
支持优雅关闭
2. Server-Sent Events (SSE)
async def generate_stream(...) -> AsyncGenerator[str, None]:for token_column, token_masks in engine.generate(...):token = token_column[0]token_text = tokenizer.decode([token])yield f"data: {json.dumps({'token': token_text})}\n\n"yield f"data: {json.dumps({'done': True})}\n\n"
SSE格式:
data: {"token": "你"}data: {"token": "好"}data: {"token": "!"}data: {"done": true}
浏览器端接收:
const eventSource = new EventSource('/chat/completions');
eventSource.onmessage = (event) => {const data = JSON.parse(event.data);if (data.done) {eventSource.close();} else {displayToken(data.token);}
};
3. 跨域支持
app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_credentials=True,allow_methods=["*"],allow_headers=["*"],
)
允许从任何域访问(生产环境应限制origins)。
8.2 前端UI
nanochat自带一个优雅的Web界面(单文件HTML+JavaScript):
<!DOCTYPE html>
<html>
<head><style>/* ChatGPT风格的样式 */.message-user { background: #f0f0f0; }.message-assistant { background: white; }</style>
</head>
<body><div id="chat-container"></div><input id="user-input" type="text" placeholder="Send a message..."><script>async function sendMessage(message) {const response = await fetch('/chat/completions', {method: 'POST',headers: {'Content-Type': 'application/json'},body: JSON.stringify({messages: [...conversationHistory, {role: 'user', content: message}],stream: true})});const reader = response.body.getReader();const decoder = new TextDecoder();while (true) {const {done, value} = await reader.read();if (done) break;const chunk = decoder.decode(value);const lines = chunk.split('\n');for (const line of lines) {if (line.startsWith('data: ')) {const data = JSON.parse(line.slice(6));if (data.token) {appendToLastMessage(data.token);}}}}}</script>
</body>
</html>
特性:
-
流式显示:逐token渲染,体验流畅
-
Markdown支持:代码块、列表、链接等
-
对话历史:多轮对话上下文管理
-
响应式设计:适配桌面和移动端
8.3 性能优化
8.3.1 批量推理
虽然Web服务是单用户场景,但仍可用批量优化:
# 在prefill阶段,batch=1
kv_cache_prefill = KVCache(batch_size=1, ...)
logits = model.forward(prompt_tokens, kv_cache=kv_cache_prefill)# 在decode阶段,可以并行生成多个候选
kv_cache_decode = KVCache(batch_size=5, ...) # 5个候选
kv_cache_decode.prefill(kv_cache_prefill)# 生成5个候选,选最优的返回
candidates, _ = engine.generate_batch(prompt_tokens, num_samples=5, ...)
best_candidate = select_best(candidates) # 可以用reward model打分
Best-of-N采样:
-
生成N个候选回复
-
用reward model或启发式规则选最优
-
质量提升明显,但推理成本增加N倍
8.3.2 投机解码(Speculative Decoding)
这是一个前沿优化技术(nanochat未实现,但值得一提):
# 用小模型(fast)猜测接下来的k个tokens
draft_tokens = small_model.generate(prompt, max_tokens=k)# 用大模型(slow)并行验证这k个tokens
logits = large_model.forward(torch.cat([prompt, draft_tokens]))
acceptance = verify_tokens(logits, draft_tokens)# 接受正确的tokens,拒绝错误的
accepted_count = acceptance.sum()
result = draft_tokens[:accepted_count]
理论加速比:2-3倍(取决于小模型的准确率)。
8.4 部署建议
云平台选择:
平台 | GPU | 价格 | 适用场景 |
---|---|---|---|
Lambda Labs | H100 | $2.49/h | 训练(性价比高) |
RunPod | A40 | $0.79/h | 推理(便宜) |
Vast.ai | V100 | $0.20/h | 开发调试 |
AWS | A100 | $4.10/h | 生产环境(稳定) |
推理优化:
# 使用bfloat16推理(速度快,精度损失小)
model = model.bfloat16()# 启用编译优化
model = torch.compile(model, mode="reduce-overhead")# 增大batch size(延迟换吞吐)
engine.generate_batch(..., num_samples=8)
监控指标:
-
延迟(Latency):首token时间(TTFT)、平均token时间
-
吞吐(Throughput):tokens/秒
-
资源利用率:GPU利用率、显存占用
-
可用性:请求成功率、错误率
九、实战应用:从玩具到生产
9.1 垂直领域模型
nanochat的最大价值在于可定制性。几个实战方向:
9.1.1 法律助手
# 1. 在法律语料上继续预训练(domain adaptation)
legal_corpus = load_dataset("legal_cases", "judgments", "laws")
model = load_model("base", device, phase="train")
train(model, legal_corpus, num_iterations=5000)# 2. 在法律QA数据上SFT
legal_qa = load_dataset("legal_qa")
sft_train(model, legal_qa, num_iterations=1000)# 3. 部署为法律咨询服务
app = FastAPI()
@app.post("/legal_advice")
async def legal_advice(question: str):prompt = f"作为专业律师,请回答以下法律问题:\n{question}"response = engine.generate(tokenizer.encode(prompt), ...)return {"advice": tokenizer.decode(response)}
关键点:
-
Domain adaptation很重要:法律术语、判例引用等
-
数据质量>数量:1000条高质量QA胜过10000条低质量
-
需要免责声明:AI建议仅供参考
9.1.2 代码助手
# 在GitHub代码上预训练
code_corpus = load_dataset("codeparrot/github-code")
model = load_model("base", device, phase="train")
train(model, code_corpus, num_iterations=10000)# 在code completion任务上微调
humaneval = load_dataset("openai_humaneval")
mbpp = load_dataset("mbpp")
sft_train(model, humaneval + mbpp, num_iterations=500)# 集成到VS Code
def code_completion(prefix, suffix):prompt = f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"completion = engine.generate(tokenizer.encode(prompt), temperature=0.2, top_k=50)return tokenizer.decode(completion)
性能提升:
-
在HumanEval上从7%提升到15-20%(仅需2-3小时微调)
-
Pass@10(生成10个候选)可达30-40%
9.1.3 客服机器人
# 在企业内部FAQ上微调
faq_data = [{"user": "如何退款?", "assistant": "退款流程是..."},{"user": "订单状态查询", "assistant": "请提供订单号..."},...
]# 添加工具调用(查询订单系统)
def query_order(order_id):return database.get_order(order_id)# 在生成时调用工具
conversation = [{"role": "user", "content": "订单12345的状态"},{"role": "assistant", "content": [{"type": "text", "text": "让我查一下"},{"type": "python", "text": f"query_order('12345')"},{"type": "python_output", "text": "{'status': 'shipped', 'eta': '2025-10-20'}"},{"type": "text", "text": "您的订单已发货,预计10月20日送达"},]}
]
优势:
-
成本低:相比调用GPT-4 API节省90%+
-
低延迟:本地部署,<100ms首token
-
数据隐私:敏感信息不出企业内网
9.2 研究方向
nanochat也是绝佳的研究平台:
9.2.1 数据效率研究
问题:如何用更少数据训练出更好的模型?
# 实验1:Curriculum learning(课程学习)
easy_data = filter_by_difficulty(all_data, difficulty="easy")
hard_data = filter_by_difficulty(all_data, difficulty="hard")train(model, easy_data, num_iterations=5000) # 先学简单的
train(model, hard_data, num_iterations=5000) # 再学难的# 实验2:Data pruning(数据剪枝)
scores = [score_quality(doc) for doc in all_data]
top_data = [doc for doc, score in zip(all_data, scores) if score > threshold]train(model, top_data, num_iterations=10000) # 只用高质量数据
预期发现:
-
Curriculum learning可能提升5-10%效果
-
Data pruning可以用50%数据达到80%效果
9.2.2 架构探索
问题:哪些架构改动在小模型上有效?
# 实验1:不同激活函数
class MLPWithGELU(nn.Module):def forward(self, x):return self.c_proj(F.gelu(self.c_fc(x)))class MLPWithSwiGLU(nn.Module):def forward(self, x):gate, up = self.c_fc(x).chunk(2, dim=-1)return self.c_proj(F.silu(gate) * up)# 实验2:不同attention变体
class SlidingWindowAttention(nn.Module):def forward(self, q, k, v):# 只attend到最近w个tokensattn_mask = get_sliding_window_mask(window_size=512)return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
方法论:
-
控制变量:只改一个超参数
-
多次实验:随机种子不同,跑3-5次取平均
-
评估全面:不只看loss,还要看downstream任务
9.2.3 优化算法研究
问题:Muon之外还有更好的优化器吗?
# 实验1:混合精度的最佳实践
configs = [{"forward": "bf16", "backward": "fp32", "optimizer": "fp32"},{"forward": "fp16", "backward": "fp32", "optimizer": "fp32"},{"forward": "fp8", "backward": "fp16", "optimizer": "fp32"}, # 未来的FP8
]# 实验2:学习率调度
schedulers = ["linear_warmdown","cosine_annealing","inverse_sqrt","constant",
]# 实验3:Batch size vs Learning rate
for batch_size in [128k, 256k, 512k, 1M]:for lr_scale in [0.5, 1.0, 2.0]:train(model, data, batch_size=batch_size, lr=base_lr * lr_scale)
9.3 教学应用
nanochat非常适合作为教学材料:
9.3.1 大学课程
课程大纲(4周):
Week 1:Transformer基础
-
阅读
gpt.py
,理解self-attention、MLP、LayerNorm -
作业:手写一个mini-transformer(100行)
-
实验:训练一个character-level language model
Week 2:高效训练
-
学习mixed precision、gradient accumulation、DDP
-
阅读
base_train.py
,理解训练循环 -
作业:在小数据集上复现训练
Week 3:分词与数据
-
理解BPE算法,阅读
rustbpe/src/lib.rs
-
学习数据流pipeline,阅读
dataloader.py
-
作业:训练自己的分词器
Week 4:推理与部署
-
学习KV cache、sampling策略
-
阅读
engine.py
和chat_web.py
-
作业:部署一个Web服务
9.3.2 在线教程
制作step-by-step教程:
# nanochat从零开始## Part 1: 环境搭建(15分钟)
\```bash
git clone https://github.com/karpathy/nanochat
cd nanochat
uv sync
\```## Part 2: 训练玩具模型(30分钟)
\```bash
# 下载1个shard(100MB)
python -m nanochat.dataset -n 1# 训练小tokenizer(10K词表)
python -m scripts.tok_train --vocab_size=10000 --max_chars=100000000# 训练tiny模型(d4, 22M参数)
python -m scripts.base_train -- --depth=4 --num_iterations=100
\```## Part 3: 对话(10分钟)
\```bash
python -m scripts.chat_web
# 访问 http://localhost:8000
\```## 思考题
1. 为什么用ReLU²而不是GELU?
2. Muon相比Adam的优势是什么?
3. 如何减少模型的推理延迟?
十、未来展望:小模型的星辰大海
10.1 技术演进方向
10.1.1 量化与压缩
当前状态:nanochat用bfloat16训练和推理
未来方向:
# INT8量化:无损精度,速度提升2倍
from torch.quantization import quantize_dynamic
model_int8 = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)# INT4量化:轻微精度损失,速度提升4倍
from transformers import BitsAndBytesConfig
model_int4 = load_model(..., quantization_config=BitsAndBytesConfig(load_in_4bit=True))# 混合精度推理:重要层用高精度,其他层用低精度
for name, module in model.named_modules():if "mlp" in name:module = module.to(dtype=torch.int8) # MLP用INT8else:module = module.to(dtype=torch.bfloat16) # Attention用BF16
预期收益:
-
INT8:速度+100%,精度-1%
-
INT4:速度+300%,精度-3-5%
-
混合精度:速度+150%,精度-2%
10.1.2 稀疏化(Sparsity)
观察:Transformer的权重矩阵很多元素接近0
方法:
# 结构化稀疏(2:4 sparsity)
# 每4个元素中至少2个为0
def structured_prune(weight, ratio=0.5):# 每4个元素一组w = weight.view(-1, 4)# 保留每组最大的2个topk_vals, topk_idx = torch.topk(w.abs(), k=2, dim=-1)mask = torch.zeros_like(w)mask.scatter_(-1, topk_idx, 1)return weight * mask.view_as(weight)# 应用到模型
for module in model.modules():if isinstance(module, nn.Linear):module.weight.data = structured_prune(module.weight.data)
H100的2:4稀疏加速:
-
理论加速:2倍(减少50%计算)
-
实际加速:1.6倍(内存带宽限制)
-
精度损失:<2%(通过sparse-aware training补偿)
10.1.3 长文本支持
当前限制:2048 tokens(约6000字)
扩展方法:
方法1:位置插值(Position Interpolation)
# 训练时:seq_len=2048
cos, sin = precompute_rotary(seq_len=2048, base=10000)# 推理时:seq_len=8192,插值缩放
cos, sin = precompute_rotary(seq_len=8192, base=10000 * 4)
无需重新训练,外推到4倍长度。
方法2:Attention Sink
# 保留前k个tokens的attention(作为"sink")
def attention_with_sink(q, k, v, sink_size=4):# 前sink_size个tokens总是被attend# 后续只attend到滑动窗口...
支持无限长度,精度损失小。
方法3:分层注意力(Hierarchical Attention)
# 低层:local attention(窗口512)
# 中层:strided attention(步长4)
# 高层:global attention(全部)
复杂度从O(n²)降到O(n log n)。
10.2 应用场景拓展
10.2.1 边缘设备部署
目标:在手机/树莓派上运行nanochat
方案:
-
模型压缩:量化到INT4,剪枝到50%稀疏
-
架构优化:减少层数(d20→d12),缩小宽度
-
推理框架:llama.cpp、GGML等C++实现
-
结果:100M参数,<500MB内存,<1s首token
应用:
-
离线翻译助手
-
本地笔记整理
-
隐私优先的个人助理
10.2.2 多模态扩展
文本+图像:
# 添加视觉编码器
class VisionEncoder(nn.Module):def __init__(self):self.vit = VisionTransformer(...)self.projector = nn.Linear(vision_dim, text_dim)def forward(self, image):features = self.vit(image)return self.projector(features) # 投影到文本空间# 融合到GPT
class MultimodalGPT(GPT):def forward(self, text_tokens=None, image=None):if image is not None:image_features = self.vision_encoder(image)# 拼接图像特征和文本tokensx = torch.cat([image_features, self.embed(text_tokens)], dim=1)else:x = self.embed(text_tokens)# ... 正常Transformer处理 ...
文本+音频:
class AudioEncoder(nn.Module):def __init__(self):self.whisper = WhisperEncoder(...) # 音频特征提取self.projector = nn.Linear(audio_dim, text_dim)
10.2.3 联邦学习(Federated Learning)
场景:多个医院想联合训练医疗模型,但不能共享病历
方案:
# 中心服务器
global_model = GPT(...)for round in range(num_rounds):# 1. 分发模型到各医院for hospital in hospitals:hospital.receive_model(global_model)# 2. 各医院独立训练local_updates = []for hospital in hospitals:local_model = hospital.train_local(num_steps=100)local_updates.append(local_model.state_dict())# 3. 聚合更新(联邦平均)global_state = global_model.state_dict()for key in global_state:global_state[key] = sum([u[key] for u in local_updates]) / len(local_updates)global_model.load_state_dict(global_state)
nanochat的优势:
-
模型小,通信开销低
-
训练快,每轮只需几分钟
-
代码简单,易于审计和信任
10.3 社区与生态
nanochat已经形成了活跃的社区:
贡献方向:
-
新任务评测:添加更多benchmark(GLUE、SuperGLUE等)
-
优化技巧:Flash Attention 3、Paged Attention等
-
工具集成:Weights & Biases、MLflow等
-
多语言支持:中文、日文等非英语模型
-
教程文档:视频教程、交互式notebook
Fork衍生项目:
-
nanochat-medical
:医疗领域模型 -
nanochat-code
:代码生成专用 -
nanochat-zh
:中文优化版本 -
nanochat-tiny
:<100M参数的超小模型
十一、总结:极简主义的哲学
回顾整个nanochat项目,最打动我的不是某个具体技术,而是贯穿始终的极简主义哲学:
11.1 Less is More(少即是多)
在一个充斥着"大力出奇迹"的时代,nanochat逆流而上:
-
不用10万行代码,只用8千行
-
不花1000万美元,只花100美元
-
不追求SOTA性能,只追求可理解性
这种克制带来了:
-
更低的认知负担:任何人都能在一周内掌握
-
更高的灵活性:想改就改,没有历史包袱
-
更快的迭代速度:从想法到验证,只需几小时
11.2 Simplicity is Sophistication(简单即精致)
nanochat的简单不是简陋,而是深思熟虑的结果:
-
选择ReLU²而非GELU:深思熟虑的简化
-
使用RMSNorm无参数:化繁为简的智慧
-
Muon优化器:在理论深度和实现简洁间取得平衡
-
双实现分词器:在训练和推理间找到最优解
每一行代码都经过精心打磨,没有冗余,没有炫技。
11.3 Hackable is Powerful(可黑客化即强大)
nanochat最大的价值不是产出一个模型,而是赋能每个人:
-
研究者:快速验证新想法
-
工程师:学习工业级实践
-
创业者:低成本构建垂直模型
-
学生:理解LLM工作原理
这种"授人以渔"的理念,远比"授人以鱼"更有意义。
11.4 Personal Reflection(个人感悟)
作为一个深度学习从业者,看完nanochat的代码后我深受震撼。
在过去几年,我们见证了模型越来越大、代码越来越复杂、门槛越来越高。很多人(包括我)开始怀疑:普通开发者还有机会吗?
nanochat给出了响亮的答案:有!
你不需要数百张A100,不需要精通CUDA编程,不需要读遍所有论文。你只需要:
-
一台带GPU的电脑(或租一台)
-
扎实的PyTorch基础
-
对LLM的好奇心
100美元和一个周末,你就能训练出属于自己的ChatGPT。虽然它不会超越GPT-4,但它是真正属于你的——你理解每一行代码,你掌握每个超参数,你可以随心所欲地改造它。
这种掌控感,是任何API调用无法给予的。
11.5 Call to Action(行动号召)
如果你读到这里,我强烈建议你:
1. 克隆项目,跑起来
git clone https://github.com/karpathy/nanochat
cd nanochat
bash speedrun.sh
2. 阅读核心代码 按顺序读:gpt.py
→ engine.py
→ tokenizer.py
→ base_train.py
3. 做一个小实验
-
换个激活函数?
-
改个学习率调度?
-
加个新的评测任务?
4. 分享你的发现
-
写博客记录实验
-
在GitHub提PR
-
在社区讨论心得
从之前的micrograd、nanoGPT到现在的nanochat,他一直在践行"教育优先、简洁优先"的理念。在一个追求论文数量和引用量的学术环境中,他选择了一条更难但更有意义的路——让每个人都能理解AI。
这种精神值得我们每个人学习。
参考资源
官方资源:
-
GitHub仓库:https://github.com/karpathy/nanochat
-
作者讲解:https://github.com/karpathy/nanochat/discussions/1
-
课程LLM101n:https://github.com/karpathy/LLM101n
相关论文:
-
"Attention Is All You Need" (Transformer原论文)
-
"Chinchilla: Training Compute-Optimal Large Language Models"
-
"Muon: MomentUm Orthogonalized by Newton-schulz"
-
"RoFormer: Enhanced Transformer with Rotary Position Embedding"
延伸阅读:
-
nanoGPT:Transformer预训练的极简实现
-
modded-nanoGPT:优化版nanoGPT,很多技巧被nanochat采用
-
llm.c:纯C实现的LLM训练,终极性能
社区资源:
-
Discord讨论组
-
Reddit r/MachineLearning
-
Twitter #nanochat
更多AIGC文章