LM实现教程:基于 nanochat项目 从零开始理解大语言模型
LLM实现教程:从零开始理解大语言模型
基于 nanochat 仓库的深入解析
引言
nanochat 是一个完整的LLM实现,成本约100美元即可训练的ChatGPT类模型。本教程将带你深入理解:
- LLM的基本原理
- Transformer架构的核心实现
- 完整的训练流程(预训练、微调、强化学习)
- 各种优化技术和分布式训练
为何选择nanochat?
- 代码精简:约8K行代码,易于理解
- 端到端:从分词到部署的完整流程
- 约100美元:成本低
- 可配置性高:代码结构清晰,易于修改
LLM基础知识
什么是LLM?
大语言模型(Large Language Model, LLM)是一类基于Transformer架构的深度学习模型,通过学习大量文本数据来理解和生成自然语言。
核心概念
1. Token与Tokenization(分词)
为什么需要分词?
神经网络无法直接处理文本,需要将文本转换为数字序列。
BPE(Byte Pair Encoding) 是主流方法:
"""
BPE Tokenizer in the style of GPT-4.Two implementations are available:
1) HuggingFace Tokenizer that can do both training and inference but is really confusing
2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
"""import os
import copy
from functools import lru_cacheSPECIAL_TOKENS = [# every document begins with the Beginning of Sequence (BOS) token that delimits documents"<|bos|>",# tokens below are only used during finetuning to render Conversations into token ids"<|user_start|>", # user messages"<|user_end|>","<|assistant_start|>", # assistant messages"<|assistant_end|>","<|python_start|>", # assistant invokes python REPL tool"<|python_end|>","<|output_start|>", # python REPL outputs back to assistant"<|output_end|>",
]
工作原理:
- 从单个字符开始
- 统计最常出现的相邻符号对
- 合并成新符号
- 重复上述过程直到达到目标词汇表大小
nanochat的实现:
class RustBPETokenizer:"""Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""def __init__(self, enc, bos_token):self.enc = encself.bos_token_id = self.encode_special(bos_token)@classmethoddef train_from_iterator(cls, text_iterator, vocab_size):# 1) train using rustbpetokenizer = rustbpe.Tokenizer()# the special tokens are inserted later in __init__, we don't train them herevocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)# 2) construct the associated tiktoken encoding for inferencepattern = tokenizer.get_pattern()mergeable_ranks_list = tokenizer.get_mergeable_ranks()mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}tokens_offset = len(mergeable_ranks)special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}enc = tiktoken.Encoding(name="rustbpe",pat_str=pattern,mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)special_tokens=special_tokens, # dict[str, int] (special token name -> token id))return cls(enc, "<|bos|>")
2. Transformer架构
核心组件:
- Embeddings(嵌入层):将token id转换为向量
- Attention(注意力机制):让模型关注相关的上下文
- Feed-Forward(前馈网络):非线性变换
- Layer Norm(层归一化):稳定训练
class GPT(nn.Module):def __init__(self, config):super().__init__()self.config = configself.transformer = nn.ModuleDict({"wte": nn.Embedding(config.vocab_size, config.n_embd),"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),})self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)# To support meta device initialization, we init the rotary embeddings here, but it's fake# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,# so let's just over-compute them, but assert fail if we ever reach that amount.# In the future we can dynamically grow the cache, for now it's fine.self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?head_dim = config.n_embd // config.n_headcos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpointself.register_buffer("sin", sin, persistent=False)
3. 注意力机制(Attention)
**自注意力(Self-Attention)**让模型理解序列内部关系:
class CausalSelfAttention(nn.Module):def __init__(self, config, layer_idx):super().__init__()self.layer_idx = layer_idxself.n_head = config.n_headself.n_kv_head = config.n_kv_headself.n_embd = config.n_embdself.head_dim = self.n_embd // self.n_headassert self.n_embd % self.n_head == 0assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)def forward(self, x, cos_sin, kv_cache):B, T, C = x.size()# Project the input to get queries, keys, and valuesq = self.c_q(x).view(B, T, self.n_head, self.head_dim)k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)# Apply Rotary Embeddings to queries and keys to get relative positional encodingcos, sin = cos_sinq, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embeddingq, k = norm(q), norm(k) # QK normq, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)# Apply KV cache: insert current k,v into cache, get the full view so farif kv_cache is not None:k, v = kv_cache.insert_kv(self.layer_idx, k, v)Tq = q.size(2) # number of queries in this forward passTk = k.size(2) # number of keys/values in total (in the cache + current forward pass)# Attention: queries attend to keys/values autoregressively. A few cases to handle:enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desiredif kv_cache is None or Tq == Tk:# During training (no KV cache), attend as usual with causal attention# And even if there is KV cache, we can still use this simple version when Tq == Tky = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)elif Tq == 1:# During inference but with a single query in this forward pass:# The query has to attend to all the keys/values in the cachey = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)else:# During inference AND we have a chunk of queries in this forward pass:# First, each query attends to all the cached keys/values (i.e. full prefix)attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = maskprefix_len = Tk - Tqif prefix_len > 0: # can't be negative but could be zeroattn_mask[:, :prefix_len] = True# Then, causal attention within this chunkattn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)# Re-assemble the heads side by side and project back to residual streamy = y.transpose(1, 2).contiguous().view(B, T, -1)y = self.c_proj(y)return y
关键点:
- Q(Query):查询向量,寻找信息
- K(Key):键向量,提供信息位置
- V(Value):值向量,实际信息内容
- 因果掩码(Causal Mask):确保只能看到当前位置之前的token
4. 位置编码
Rotary Position Embedding (RoPE):
def apply_rotary_emb(x, cos, sin):assert x.ndim == 4 # multihead attentiond = x.shape[3] // 2x1, x2 = x[..., :d], x[..., d:] # split up last time into two halvesy1 = x1 * cos + x2 * sin # rotate pairs of dimsy2 = x1 * (-sin) + x2 * cosout = torch.cat([y1, y2], 3) # re-assembleout = out.to(x.dtype) # ensure input/output dtypes matchreturn out
RoPE通过旋转向量的方式编码位置信息,比传统的绝对位置嵌入更优雅。
nanochat架构概览
整体流程
1. 数据下载 → 2. 训练分词器 → 3. 预训练 → 4. 中训练 → 5. SFT微调 → 6. 评估/部署
目录结构
nanochat/
├── nanochat/ # 核心代码
│ ├── gpt.py # GPT模型定义
│ ├── engine.py # 推理引擎(KV Cache)
│ ├── tokenizer.py # BPE分词器
│ ├── dataloader.py # 数据加载
│ ├── adamw.py # AdamW优化器
│ ├── muon.py # Muon优化器
│ └── execution.py # Python代码执行工具
├── scripts/ # 训练脚本
│ ├── base_train.py # 预训练
│ ├── mid_train.py # 中训练
│ ├── chat_sft.py # SFT训练
│ └── chat_rl.py # RL训练
└── tasks/ # 评估任务
核心组件详解
1. GPT模型
模型配置:
@dataclass
class GPTConfig:sequence_len: int = 1024vocab_size: int = 50304n_layer: int = 12n_head: int = 6 # number of query headsn_kv_head: int = 6 # number of key/value heads (MQA)n_embd: int = 768
Transformer Block:
class Block(nn.Module):def __init__(self, config, layer_idx):super().__init__()self.attn = CausalSelfAttention(config, layer_idx)self.mlp = MLP(config)def forward(self, x, cos_sin, kv_cache):x = x + self.attn(norm(x), cos_sin, kv_cache)x = x + self.mlp(norm(x))return x
前馈网络:
class MLP(nn.Module):def __init__(self, config):super().__init__()self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)def forward(self, x):x = self.c_fc(x)x = F.relu(x).square()x = self.c_proj(x)return x
关键特性:
- 无偏置:所有线性层不使用偏置
- relu^2 激活:
relu(x)²,较标准ReLU表现更好 - 残差连接:
x = x + f(x) - 层归一化:使用RMSNorm
- 嵌入与输出头不共享权重:untied weights(
wte与lm_head独立)
2. 推理引擎(KV Cache)
为什么需要KV Cache?
在推理时,之前计算的key-value可以缓存,避免重复计算:
class KVCache:"""Works hand-in-hand with the GPT model to maintain the KV cache.Note that the .pos advances automatically after the last layer of the Transformer inserts."""def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)self.kv_cache = Noneself.pos = 0 # current position in time in the cachedef reset(self):self.pos = 0def get_pos(self):return self.posdef prefill(self, other):"""Prefill given another KV cache. Optionally expand along batch dim.This is used when we do batch 1 prefill and then want to generatemultiple samples in parallel from there."""# 1) validate the shapesassert self.kv_cache is None, "Cannot prefill a non-empty KV cache"assert other.kv_cache is not None, "Cannot prefill with a None KV cache"for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):if ix in [0, 1, 3, 5]:# num_layers, batch_size, num_heads, head_dim must matchassert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}"elif ix == 2:# batch_size can be expandedassert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"elif ix == 4:# seq_len: self must be longer than otherassert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"# 2) initialize the cachedtype, device = other.kv_cache.dtype, other.kv_cache.deviceself.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)# 3) copy the data overself.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache# 4) update the posself.pos = other.posdef insert_kv(self, layer_idx, k, v):# Lazy initialize the cache here because we need to know the dtype/deviceif self.kv_cache is None:self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)# Insert new keys/values to the cache and return the full cache so farB, H, T_add, D = k.size()t0, t1 = self.pos, self.pos + T_add# Dynamically grow the cache if neededif t1 > self.kv_cache.size(4):t_needed = t1 + 1024 # as much as we need plus buffer of 1024t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024current_shape = list(self.kv_cache.shape)current_shape[4] = t_neededself.kv_cache.resize_(current_shape)# Insert k, v into the cacheself.kv_cache[layer_idx, 0, :, :, t0:t1] = kself.kv_cache[layer_idx, 1, :, :, t0:t1] = v# Return the full cached keys/values up to current position (as a view)key_view = self.kv_cache[layer_idx, 0, :, :, :t1]value_view = self.kv_cache[layer_idx, 1, :, :, :t1]# Increment pos after the last layer of the Transformer processesif layer_idx == self.kv_cache.size(0) - 1:self.pos = t1return key_view, value_view
3. 数据加载
分布式数据加载器:
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):"""Stream pretraining text from parquet files, tokenize, yield training batches."""assert split in ["train", "val"], "split must be 'train' or 'val'"ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()needed_tokens = B * T + 1 # +1 is because we also need the target at the last token# get the tokenizer and the bos tokentokenizer = get_tokenizer()bos_token = tokenizer.get_bos_token_id()# scratch buffer holds the tokens for one iterationtoken_buffer = deque() # we stream tokens on the right and pop from the left# infinite iterator over document batchesdef document_batches():while True:# batch will iterate in group size of the parquet files, usually e.g. 1024 rowsfor batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):# for the tokenizer we might want to go in usually smaller batches, e.g. 128 rowsfor i in range(0, len(batch), tokenizer_batch_size):yield batch[i:i+tokenizer_batch_size]batches = document_batches()batch_index = 0while True:# Accumulate enough tokens for one iteration before yielding.while len(token_buffer) < needed_tokens:doc_batch = next(batches)token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)for tokens in token_lists:token_buffer.extend(tokens)batch_index += 1# Move tokens from the deque into the scratch buffertokens = [token_buffer.popleft() for _ in range(needed_tokens)]# CUDA supports memory pinning for faster transfers between CPU and GPU:scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))# Create the inputs/targets as 1D tensorsinputs_cpu = scratch[:-1].to(dtype=torch.int32)targets_cpu = scratch[1:]# Reshape to 2D and move to GPU asyncinputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)yield inputs, targets
关键点:
- 滑动窗口:将长文本切分成固定长度序列
- 异步传输:使用
non_blocking=True加速数据传输 - 内存固定:
pin_memory加速CPU到GPU传输
训练流程
阶段1:预训练(Base Training)
目标:在大规模文本上学习语言模型
核心代码:
# Initialize the Model
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
with torch.device("meta"):model_config = GPTConfig(**model_config_kwargs)model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:# calculate the number of iterations from the target flopsnum_iterations = round(target_flops / (num_flops_per_token * total_batch_size))print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:# calculate the number of iterations from the target param data ratiotarget_tokens = target_param_data_ratio * num_paramsnum_iterations = target_tokens // total_batch_sizeprint0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
关键超参数:
- Chinchilla定律:数据量 ≈ 20 × 参数量
- 学习率:分层设置
- 批次大小:总批次大小 = device_batch_size × world_size × grad_accum_steps
阶段2:中训练(Mid Training)
目标:引入对话格式、工具使用、选择题
混合训练数据:
train_dataset = TaskMixture([SmolTalk(split="train"), # 460K rows of general conversationsMMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACEGSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool useCustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversationsCustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of theseSimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
阶段3:监督微调(SFT)
目标:进一步优化对话能力和任务表现
train_ds = TaskMixture([ARC(subset="ARC-Easy", split="train"), # 2.3K rowsARC(subset="ARC-Challenge", split="train"), # 1.1K rowsGSM8K(subset="main", split="train"), # 8K rowsSmolTalk(split="train", stop=10_000), # 10K rows of smoltalkCustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversationsSimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # 2.3K + שחumbles + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
对话渲染:
def render_conversation(self, conversation, max_tokens=2048):"""Tokenize a single Chat conversation (which we call a "doc" or "document" here).Returns:- ids: list[int] is a list of token ids of this rendered conversation- mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on."""# ids, masks that we will return and a helper function to help build them up.ids, mask = [], []def add_tokens(token_ids, mask_val):if isinstance(token_ids, int):token_ids = [token_ids]ids.extend(token_ids)mask.extend([mask_val] * len(token_ids))# sometimes the first message is a system message...# => just merge it with the second (user) messageif conversation["messages"][0]["role"] == "system":# some conversation surgery is necessary here for now...conversation = copy.deepcopy(conversation) # avoid mutating the originalmessages = conversation["messages"]assert messages[1]["role"] == "user", "System message must be followed by a user message"messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]messages = messages[1:]else:messages = conversation["messages"]assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}"# fetch all the special tokens we needbos = self.get_bos_token_id()user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")# now we can tokenize the conversationadd_tokens(bos, 0)for i, message in enumerate(messages):# some sanity checking here around assumptions, to prevent footgunsmust_be_from = "user" if i % 2 == 0 else "assistant"assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"# content can be either a simple string or a list of parts (e.g. containing tool calls)content = message["content"]if message["role"] == "user":assert isinstance(content, str), "User messages are simply expected to be strings"value_ids = self.encode(content)add_tokens(user_start, 0)add_tokens(value_ids, 0)add_tokens(user_end, 0)elif message["role"] == "assistant":add_tokens(assistant_start, 0)if isinstance(content, str):# simple string => simply add the tokensvalue_ids = self.encode(content)add_tokens(value_ids, 1)elif isinstance(content, list):for part in content:value_ids = self.encode(part["text"])if part["type"] == "text":# string part => simply add the tokensadd_tokens(value_ids, 1)elif part["type"] == "python":# python tool call => add the tokens inside <|python_start|> and <|python_end|>add_tokens(python_start, 1)add_tokens(value_ids, 1)add_tokens(python_end, 1)elif part["type"] == "python_output":# python output => add the tokens inside <|output_start|> and <|output_end|># none of these tokens are supervised because the tokens come from Python at test timeadd_tokens(output_start, 0)add_tokens(value_ids, 0)add_tokens(output_end, 0)else:raise ValueError(f"Unknown part type: {part['type']}")else:raise ValueError(f"Unknown content type: {type(content)}")add_tokens(assistant_end, 1)# truncate to max_tokens tokens MAX (helps prevent OOMs)ids = ids[:max_tokens]mask = mask[:max_tokens]return ids, mask
关键点:
- mask机制:只对assistant回复计算损失
- 特殊token:区分用户消息、助手消息、工具调用
关键技术细节
1. 优化器
混合优化器:对不同参数使用不同优化器
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):model_dim = self.config.n_embdddp, rank, local_rank, world_size = get_dist_info()# Separate out all parameters into 3 groups (matrix, embedding, lm_head)matrix_params = list(self.transformer.h.parameters())embedding_params = list(self.transformer.wte.parameters())lm_head_params = list(self.lm_head.parameters())assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)# Create the AdamW optimizer for the embedding and lm_head# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)dmodel_lr_scale = (model_dim / 768) ** -0.5if rank == 0:print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")adam_groups = [dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),]adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)# Create the Muon optimizer for the linear layersmuon_kwargs = dict(lr=matrix_lr, momentum=0.95)MuonFactory = DistMuon if ddp else Muonmuon_optimizer = MuonFactory(matrix_params, **muon_kwargs)# Combine them the two optimizers into one listoptimizers = [adamw_optimizer, muon_optimizer]for opt in optimizers:for group in opt.param_groups:group["initial_lr"] = group["lr"]return optimizers
Muon优化器:针对线性层,使用Newton-Schulz正交化:
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:"""Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use aquintic iteration whose coefficients are selected to maximize the slope at zero. For the purposeof minimizing steps, it turns out to be empirically effective to keep increasing the slope atzero even beyond the point where the iteration no longer converges all the way to one everywhereon the interval. This iteration therefore does not produce UV^T but rather something like US'V^Twhere S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt modelperformance at all relative to UV^T, where USV^T = G is the SVD."""assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiachenga, b, c = (3.4445, -4.7750, 2.0315)X = G.bfloat16()if G.size(-2) > G.size(-1):X = X.mT# Ensure spectral norm is at most 1X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)# Perform the NS iterationsfor _ in range(steps):A = X @ X.mTB = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiachengX = a * X + B @ Xif G.size(-2) > G.size(-1):X = X.mTreturn X
2. 分布式训练
ZeRO-2式分片:
class DistAdamW(torch.optim.Optimizer):"""Distributed AdamW optimizer.In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction"""def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)super().__init__(param_groups, defaults)@torch.compile@torch.no_grad()def step(self):rank = dist.get_rank()world_size = dist.get_world_size()reduce_scatter_futures: list[torch.Future] = []all_reduce_futures: list[torch.Future] = []grad_slices = []for group in self.param_groups:params: list[Tensor] = group["params"]for base_i in range(len(params)):grad = params[base_i].gradrank_size = grad.shape[0] // world_sizegrad_slice = torch.empty_like(grad[:rank_size])reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())grad_slices.append(grad_slice)idx = 0for group in self.param_groups:beta1, beta2 = group['betas']eps = group['eps']wd = group['weight_decay']params = group['params']for base in range(len(params)):reduce_scatter_futures[idx].wait()p = params[base]rank_size = p.shape[0] // world_sizep_slice = p[rank * rank_size:(rank + 1) * rank_size]lr = group['lr'] * getattr(p, "lr_mul", 1.0)state = self.state[p]g_slice = grad_slices[idx]# State initif not state:state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)state['exp_avg'] = torch.zeros_like(p_slice)state['exp_avg_sq'] = torch.zeros_like(p_slice)exp_avg = state['exp_avg']exp_avg_sq = state['exp_avg_sq']state['step'] += 1t = state['step']# weight decayif wd != 0:eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)p_slice.mul_(1 - eff_weight_decay)# update running averagesexp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)# bias correctionsbias1 = 1 - beta1 ** tbias2 = 1 - beta2 ** t# compute stepdenom = exp_avg_sq.sqrt().add_(eps)step_size = lr * (torch.sqrt(bias2) / bias1)update = exp_avg.div(denom).mul_(step_size)p_slice.add_(other=update, alpha=-1.0)idx += 1all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())torch.futures.collect_all(all_reduce_futures).wait()
关键点:
- reduce_scatter:梯度平均
- all_gather:参数同步
- 异步通信:提高并行度
3. 工具使用(Tool Use)
Python执行工具:
def eval_with_timeout(formula, max_time=3):try:with timeout(max_time, formula):with warnings.catch_warnings():warnings.simplefilter("ignore", SyntaxWarning)return eval(formula)except Exception as e:signal.alarm(0)# print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usagereturn Nonedef use_calculator(expr):"""Evaluate a Python expression safely.Supports both math expressions and string operations like .count()"""# Remove commas from numbersexpr = expr.replace(",", "")# Check if it's a pure math expression (old behavior)if all([x in "0123456789*+-/.() " for x in expr]):if "**" in expr: # disallow power operatorreturn Nonereturn eval_with_timeout(expr)# Check if it's a string operation we support# Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parensallowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "if not all([x in allowed_chars for x in expr]):return None# Disallow dangerous patternsdangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file','input', 'raw_input', 'globals', 'locals', 'vars', 'dir','getattr', 'setattr', 'delattr', 'hasattr']expr_lower = expr.lower()if any(pattern in expr_lower for pattern in dangerous_patterns):return None# Only allow .count() method for now (can expand later)if '.count(' not in expr:return None# Evaluate with timeoutreturn eval_with_timeout(expr)
在生成中调用工具:
# Handle tool logic
if next_token == python_start:state.in_python_block = Truestate.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:state.in_python_block = Falseif state.python_expr_tokens:expr = self.tokenizer.decode(state.python_expr_tokens)result = use_calculator(expr)if result is not None:result_tokens = self.tokenizer.encode(str(result))state.forced_tokens.append(output_start)state.forced_tokens.extend(result_tokens)state.forced_tokens.append(output_end)state.python_expr_tokens = []
elif state.in_python_block:state.python_expr_tokens.append(next_token)
运行流程
一键运行
#!/bin/bash# This script is the "Best ChatGPT clone that $100 can buy",
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.# 1) Example launch (simplest):
# bash speedrun.sh
# 2) Example launch in a screen session (because the run takes ~4 hours):
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# 3) Example launch with wandb logging, but see below for setting up wandb first:
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh# Default intermediate artifacts directory is in ~/.cache/nanochat
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR# -----------------------------------------------------------------------------
# Python venv setup with uv# install uv (if not already installed)
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# create a .venv local virtual environment (if it doesn't exist)
[ -d ".venv" ] || uv venv
# install the repo dependencies
uv sync --extra gpu
# activate venv so that `python` uses the project's venv instead of system python
source .venv/bin/activate# -----------------------------------------------------------------------------
# wandb setup
# If you wish to use wandb for logging (it's nice!, recommended).
# 1) Make sure to first log in to wandb, e.g. run:
# `wandb login`
# 2) Set the WANDB_RUN environment variable when running this script, e.g.:
# `WANDB_RUN=d26 bash speedrun.sh`
if [ -z "$WANDB_RUN" ]; then# by default use "dummy" : it's handled as a special case, skips logging to wandbWANDB_RUN=dummy
fi# -----------------------------------------------------------------------------
# During the course of the run, we will be writing markdown reports to the report/
# directory in the base dir. This command clears it out and writes a header section
# with a bunch of system info and a timestamp that marks the start of the run.
python -m nanochat.report reset# -----------------------------------------------------------------------------
# Tokenizer# Install Rust / Cargo
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"# Build the rustbpe Tokenizer
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml# Download the first ~2B characters of pretraining dataset
# look at dev/repackage_data_reference.py for details on how this data was prepared
# each data shard is ~250M chars
# so we download 2e9 / 250e6 = 8 data shards at this point
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
python -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 240 is the right number here
python -m nanochat.dataset -n 240 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max_chars=2000000000
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval# -----------------------------------------------------------------------------
# Base model (pretraining)# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB)
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; thencurl -L -o eval_bundle.zip $EVAL_BUNDLE_URLunzip -q eval_bundle.ziprm eval_bundle.zipmv eval_bundle $NANOCHAT_BASE_DIR
fi# The d20 model is 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
# (The total number of shards available in the entire dataset is 1822.)
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID# pretrain the d20 model
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl# run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)# train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"# even better, chat with your model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web
Windows/WSL/PowerShell 提示
- 推荐在 Linux/WSL 跑完整流程;Windows 原生适合小规模验证。
- PowerShell 快速试跑(CPU/小模型示例):
# 进入仓库根目录
cd E:\open_src2\nanochat# 可选:构建 RustBPE(如需自训练分词器)
pip install maturin
maturin develop --release --manifest-path rustbpe/Cargo.toml# 安装 Python 依赖(建议先装好 PyTorch,对应你的 CUDA/CPU 环境)
pip install -e .# 运行一个极小训练以验证流程
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20# 启动 Web 聊天(如已有 checkpoint,可直接推理)
python -m scripts.chat_web
技术小结
核心技术要点
- BPE分词:平衡压缩率与计算效率
- Transformer架构:自注意力 + 前馈网络
- RoPE位置编码:Rotary Position Embedding
- 混合优化:Muon(矩阵) + AdamW(嵌入)
- 分布式训练:ZeRO-2分片
- KV Cache:加速推理
- 工具使用:Python代码执行
训练阶段
- 预训练:大规模文本学习语言模型
- 中训练:引入对话格式和特殊能力
- SFT:微调对话能力
- RL(可选):使用奖励模型优化
关键要点
- 数据规模:遵循Chinchilla定律(20×参数量)
- 计算效率:
torch.compile加速 - 内存优化:梯度累积、混合精度
- 代码质量:简洁、易读、可扩展
下一步
- 尝试修改模型架构
- 调整超参数
- 添加新的数据集
- 实现新的评估任务
- 探索强化学习
560,988,160(约5.61亿)这个参数总数是怎么计算出来的?
这实际上是基于Transformer模型的结构参数,严格、系统地统计模型各个层、各个模块的可训练参数数量。我们知道 nanochat 的以下结构信息:
- 层数(20层)
- embedding维度(model_dim=1280)
- 词表大小(vocab_size=65536)
- 注意力头数(num_heads=10)
- 每步token数/每层前后维度关系……
- 通常还有前向/后向投影,MLP中间扩张维度、输出层等
可以计算:
- Embedding层:83,886,080
- 输出层:83,886,080
- Transformer主体:393,216,000
- LayerNorm少量
总计:
83,886,080 + 83,886,080 + 393,216,000 = 560,988,160
560,988,160 就是:
- 词嵌入(输入+输出)
- 每层 attention+MLP
- 总共20层 + 以上加起来的参数和。
- 如果使用权重共享(tied embedding),输出层参数不会被重复计算。
- 这里是最经典的GPT-style Transformer参数统计方式。不同实现如FeedForward维度变化、是否有bias、RMSNorm等会略有差异。
具体来说:
# 模型的词表大小,表示一共有多少个不同的token(词/字/符号等),比如65536常用于中等多语模型
vocab_size = 65536# Transformer模型的层数(即有多少个Transformer Block,一般层数越多,能力越强)
n_layer = 20# 隐藏层维度(也称为模型宽度,通常等于embedding/hidden state/输出层维度)
model_dim = 1280# MLP隐藏层维度,通常是model_dim的4倍(即 "Feed Forward Network" 的expansion比例)
mlp_dim = 5120 # 通常4倍于模型维度# 一层Transformer中所有注意力子层(多头自注意力的Q/K/V/O)对应的线性投影参数总量
# Q: Query, K: Key, V: Value, O: Output projection,各自是一个 model_dim x model_dim 的weight,共4份
attention_proj = 4 * model_dim * model_dim # 总共4个权重矩阵的参数量# 一层Transformer中MLP部分的参数总量
# 由两个线性层组成,输入 model_dim -> 隐藏 mlp_dim -> 输出 model_dim
# 所以有两个参数矩阵:第一个 model_dim x mlp_dim,第二个 mlp_dim x model_dim
mlp_proj = 2 * model_dim * mlp_dim# 每层Transformer Block的参数总量 = 注意力部分 + 前馈神经网络(MLP)部分
per_layer = attention_proj + mlp_proj# 总参数量 = token embedding参数 + 输出层参数 + 每层block参数 * 层数
# embedding参数 = vocab_size x model_dim
# 输出层参数 = vocab_size x model_dim(通常与embedding层共用一套参数,但这里各自计入一次;如有权重共享应只加一次)
# transformer主体参数 = n_layer * per_layer
total = vocab_size * model_dim * 2 + n_layer * per_layer# 输出总参数数量,即本模型(20层,1280宽度,65536词表)的参数量
print(total) # 输出:560988160
在上面代码中提到:MLP隐藏层维度,通常是model_dim的4倍。那么:
什么是 MLP隐藏层?
MLP(全称:Multi-Layer Perceptron,多层感知机),在Transformer中的每个Block(层)都有一个MLP子模块,也称为前馈神经网络(Feed Forward Network, FFN)。
对于Transformer而言,它的结构通常是:
- 多头自注意力(Multi-Head Self Attention, MHSA)
- MLP/FFN模块
MLP一般包括2个全连接(Linear/Dense)层和一个激活函数(通常是GELU/Relu等):
- 输入维度为 model_dim(隐状态维度,比如1280)
- 经过第一个全连接层,升维到 mlp_hidden_dim
- 经过激活函数(如GELU)
- 再经过第二个全连接层,还原回原维度(mlp_hidden_dim → model_dim)
伪代码结构:
def MLP(x):x = Linear1(x) # model_dim → mlp_hidden_dimx = activation(x)x = Linear2(x) # mlp_hidden_dim → model_dimreturn x
MLP隐藏层,就是指中间那一层"升维后的"维度,通常叫 mlp_hidden_dim 或 mlp_dim。
为什么 MLP隐藏层维度通常是 model_dim 的4倍?
行业标准和经验:
几乎所有主流的Transformer论文(原始Transformer、BERT、GPT、Llama 等)都采用 mlp_dim = 4 × model_dim 作为“最佳经验值”。
- 例如,model_dim=1280(主干宽度),则mlp_hidden_dim=5120。
- 如果model_dim=1024,则mlp_hidden_dim=4096。
为什么4倍?
- 增加表示能力
Transformer的自注意力模块更关注token间的信息交换,而MLP负责对每个token内部特征做复杂投影和非线性组合。- 扩大MLP宽度,相当于给每个token分配了更高容量的“单token特征处理器”,提升每个token表征的复杂性。
- 实证最优
各种论文和大规模实验(如GPT-3、Llama设计)表明,4倍左右最能权衡能力提升和显存/效率。- 太小,网络表达能力不足。
- 太大,参数和计算量暴涨,而性能提升边际效益变低。
- 结构均衡
1倍(即和model_dim同宽)效果较差,2-4倍提升显著,但大于4、8倍收益变小——“4”是个工程上经过验证的“甜蜜点”。
本质原因
Transformer的原理强调“全局建模”(通过Attention)和“局部特征非线性拓展(MLP)”,两者黑盒功能不同。
设置较大MLP扩展,能使每个token的表达在全局混合后再做更复杂的加工,语言模型因此具备更强的语义抽象和记忆能力。
通常各类GPT/LLM的参数分布比例大致如下:
| 模型 | Embedding | Transformer Block (含Self-Attn/FFN) | 输出层 | 备注 |
|---|---|---|---|---|
| 本模型(d20) | 83.9M (~15%) | 393.2M (~70%) | 83.9M (~15%) | 输入/输出占三成,大头在transformer |
| GPT-3 175B | ~617M (<1%) | ~99% | ~617M (<1%) | 层数/宽度极大,embedding占比变小 |
| DeepSeek-V3 1.3B | 262.1M (~20%) | 1,168M | 262.1M (~20%) | 更大词表,embedding权重更高占比 |
我们以大家熟悉的 GPT-2/GPT-3/ChatGPT (对应OpenAI的GPT-3.5/4)、DeepSeek-v3 等主流开源模型进行参数量对比,同时对模型结构主要组成部分参数量进行总结:
| 模型 | 总参数数量 | 层数 | 维度(d_model) | 词表大小 | 主要组成 | Embedding参数 | Transformer Block 参数 | 输出层参数 | 典型用途/备注 |
|---|---|---|---|---|---|---|---|---|---|
| 本模型(d20) | 560,988,160 | 20 | 1280 | 65,536 | GPT-风格 | 83.9M | 393.2M | 83.9M | 微型研究/教材 |
| GPT-2 Small | 117,000,000 | 12 | 768 | 50,257 | GPT-2 | 38.5M | 78.9M | 38.5M | 入门/微型英文 |
| GPT-2 Medium | 345,000,000 | 24 | 1024 | 50,257 | GPT-2 | 51.4M | 242M | 51.4M | GPT-2 2/小型 |
| GPT-2 Large | 762,000,000 | 36 | 1280 | 50,257 | GPT-2 | 64.2M | 633M | 64.2M | 大号GPT-2 |
| GPT-2 XL | 1,542,000,000 | 48 | 1600 | 50,257 | GPT-2 | 80.4M | 1,382M | 80.4M | 超大GPT-2 |
| GPT-3 125M | 125,000,000 | 12 | 768 | 50,257 | GPT-3 tiny | 38.5M | 78.9M | 38.5M | 类GPT-2 small |
| GPT-3 350M | 350,000,000 | 24 | 1024 | 50,257 | GPT-3 mini | 51.4M | 242M | 51.4M | 类GPT-2 medium |
| GPT-3 1.3B | 1,300,000,000 | 24 | 2048 | 50,257 | GPT-3 Small | 103M | 1,094M | 103M | 微型GPT-3 |
| GPT-3 6.7B | 6,700,000,000 | 32 | 4096 | 50,257 | GPT-3 Medium | 206M | 6,288M | 206M | 小型GPT-3 |
| GPT-3 175B | 175,000,000,000 | 96 | 12,288 | 50,257 | GPT-3 | 617M | 173.8B | 617M | GPT-3 Flagship |
| DeepSeek-V2 1.3B | 1,300,000,000 | 24 | 2048 | 69,376 | DeepSeek V2 | 140M | ≈1,020M | 140M | 支持多语种 |
| DeepSeek-V3 1.3B | 1,300,000,000 | 24 | 2048 | 128,384 | DeepSeek V3 | 262.1M | ≈1,168M | 262.1M | 支持更大多语词表 |
| ChatGPT(GPT-3.5) | ~6,000,000,000? | ~96 | ~6,000-12,288? | 未公开 | OpenAI服务 | - | - | - | API/服务大模型 |
下面是一篇详细的技术文章,系统讨论了Embedding、Transformer Block(含多头注意力和前馈网络)、输出层参数的意义,并给出了本模型(d20)及其他常见模型的参数结构对比。本文适合有一定机器学习或NLP基础的读者作为学习材料或技术参考。
由此可见:
1. Embedding与输出层
- 小模型、词表大,embedding占比高。 本模型因为词表足够大,embedding和输出层各占到参数总数的15%左右。DeepSeek等支持极大词表(多语种、代码等),embedding和输出层参数占比就更高。
- 大模型,embedding占比递减。 如GPT-3 175B时,embedding+输出层<2%,绝大多数参数在"骨干"部分。
2. Transformer Block
- 是大模型参数攀升的根本所在。 层数、宽度提升明显推高参数水平,体现了大模型“主要参数在Block”的本质。
- **表达能力集中于此。**Block主导了模型的泛化、记忆和推理能力。
3. 模型设计权衡
- 想节省参数,建议优先压缩embedding/输出层(如共享embedding)。
- 若追求能力,主要提升block部分规模更为划算,embedding增长仅带来性质改变(如更好多语种/更好rare token coverage)。
- 假如模型是学生:
“embedding增长”像是学生买了新字典,能查到更冷门的单词——他不一定理解词的意思,只是“不漏掉”意思。
“Block变大”像是大脑神经元增多,逻辑推理、理解能力提高了——不单知道生词,还能举一反三、深刻推理。
GPT-3、GPT-4、Llama等超大模型,参数主力集中在Block部分,即使词表数十万,embedding参数占比也很小 - 输出层:多数情况下,输出层参数增大主要是解决token覆盖和输出精度问题,本身对模型“理解/推理抽象能力”提升有限。(理解/推理能力由Transformer block主导)
如果希望支持更多token/语种/领域,必须增大输出层参数;在多语种/跨领域/代码通用大模型场景下,输出层参数是充分表达全词表概率分布不可省略的关键 - “权重共享”在指的就是输入层(token embedding)和输出层(language modeling head)可以共用同一组权重矩阵。一个几十万词表的模型,embedding/输出各自存一遍参数,非常占空间;共享则直接省掉一半。输入和输出层都用同一组权重(embedding lookup和输出softmax共用)的实现方式:输出层权重等于词嵌入矩阵的转置。
大语言模型参数按Embedding、Transformer Block与输出层划分,清晰反映结构设计思路与资源需求。以本d20模型为例,其参数量与主流小~中型GPT类似,更大词表带来更高词嵌入/输出层参数,而大模型参数主力始终集中于transformer块。
这种参数结构认知,不仅有助于架构选型、资源预估,更有助于理解模型训练和压缩的底层逻辑。未来的多模态、超大规模模型,参数结构依然遵循这一经典“三分法”,设计者可针对实际需求优化每一部分参数,提升性价比与能力表现。
常见问题(FAQ)
- 需要多少 GPU? 最少1张可跑但很慢;推荐 8×H100 跑 $100 档 d20 约 4 小时。
- 显存不够怎么办? 降低
--device_batch_size;必要时降低--max_seq_len或模型深度/宽度,脚本会用梯度累积弥补吞吐。 - 分词器一定要训练吗? 可直接使用 tiktoken 推理;但自训练更贴合你的数据分布(压缩率更优)。
- 如何确认启用了 GQA/MQA? 查看
GPTConfig.n_kv_head与n_head的关系;n_kv_head < n_head即为 GQA/MQA 风格。 - 为什么推理需要 KV Cache? 否则每步都会重算历史 token 的注意力;KV Cache 复用历史 K/V,只算新增部分。
- 工具调用安全吗? 受限且有超时,仅允许算术与受限字符串
.count();包含危险片段会被拒绝。 - 日志与指标怎么看? 设置
WANDB_RUN使用 wandb 记录;同时会在~/.cache/nanochat/report/写入 markdown 报告。 - 如何仅在 CPU 上试跑? 参考上文 PowerShell 样例;把
--depth、--max_seq_len、--device_batch_size、--num_iterations调小即可。 - 如何用 WebUI 体验? 训练(或下载)checkpoint 后运行
python -m scripts.chat_web,浏览器访问输出地址。 - 如何切换更大模型? 提高
--depth并按 Chinchilla 比例增加数据分片;根据显存调小--device_batch_size。
参考资源
- 原仓库:https://github.com/karpathy/nanochat
- Transformer论文:Attention Is All You Need
- Chinchilla论文:Training Compute-Optimal Large Language Models
- RoPE论文:RoFormer: Enhanced Transformer with Rotary Position Embedding
本教程基于nanochat代码库编写,适合AI初学者深入理解LLM实现。
