[Chapter 4: Large Language Models (LLM)] 09. The Strongest Open-Source LLM: Llama3, Principles and Implementation - (6) Llama2 & Llama3 Code Implementation
Chapter 4: Large Language Models (LLM)
Part 9: The Strongest Open-Source LLM: Llama3, Principles and Implementation
Section 6: Llama2 & Llama3 Code Implementation
Feature | Llama2 | Llama3 | Notes |
---|---|---|---|
Release date | July 2023 | April 2024 | Llama 2 was Meta's first open-source LLM with a commercial-use license; Llama 3 is a major upgrade |
Model sizes | 7B / 13B / 70B | 8B / 70B (larger models to come) | Both generations ship a small + large configuration |
Pretraining data | ~2T tokens | ~15T tokens | Roughly 7-8x more pretraining data |
Tokenizer / vocab size | 32k (SentencePiece BPE) | 128k (tiktoken-style BPE) | Covers more languages and symbols with better encoding efficiency |
Normalization | RMSNorm | RMSNorm | Both generations use RMSNorm instead of LayerNorm; cheaper to compute |
Positional encoding | RoPE (rotary position embeddings) | RoPE + longer-context extensions | Llama 3 ships with an 8k context window (vs 4k for Llama 2) and can be extended further (e.g., to 16k) via RoPE scaling |
Attention | Multi-Head Attention (MHA); GQA only in the 70B variant | Grouped-Query Attention (GQA) | Llama 3 uses GQA across model sizes, cutting KV-cache memory and speeding up inference |
Activation | SwiGLU | SwiGLU | Unchanged across generations, replacing the traditional GELU |
KV cache | Basic caching | KV cache + more efficient parallelization | Llama 3 is tuned for long-context inference |
Training stack | ZeRO + FSDP + FlashAttention | More mature distributed training (ZeRO Stage 3, better scheduling, efficient communication) | Llama 3 puts more emphasis on training/inference throughput |
Inference performance | Good | Significantly improved (especially the larger models) | Thanks to GQA and the more efficient KV cache |
Task coverage | English, limited multilingual | English, multilingual, stronger coding ability | Llama 3 handles multilingual and code-generation tasks better |
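
To make the attention and KV-cache rows concrete, here is a small back-of-the-envelope sketch (the helper name, the fp16 assumption, and the dimensions are ours, not from any official code): it compares the per-request KV-cache size of an MHA layout (32 KV heads, Llama2-7B-like) against a GQA layout (8 KV heads, Llama3-8B-like) at a 4k context, using the same head_dim = 128 as the reference implementation below.

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    # K and V, per layer, stored as [batch, n_kv_heads, seq_len, head_dim] (fp16 assumed)
    return 2 * n_layers * batch * n_kv_heads * seq_len * head_dim * bytes_per_elem

mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096)  # MHA, Llama2-7B-like
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=4096)   # GQA, Llama3-8B-like
print(f"MHA KV cache: {mha / 2**30:.1f} GiB, GQA KV cache: {gqa / 2**30:.1f} GiB")  # ~2.0 vs ~0.5 GiB

With 8 KV heads instead of 32, the cache shrinks by 4x, which is the main reason GQA helps long-context and batched inference.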
Code
"""
Llama2 & Llama3 minimal reference implementation in PyTorch
-----------------------------------------------------------------
This single-file implementation aims to be:
- Educational: clear, commented, and faithful to core design choices.
- Practical: supports KV cache, RoPE, RMSNorm, SwiGLU, and GQA (Grouped Multi-Query Attention).
- Flexible: one Config drives either a Llama2-like (MHA) or Llama3-like (GQA) model.DISCLAIMER
- Dimensions and defaults below are illustrative; do not expect parity with official checkpoints.
- For real training/inference use highly optimized kernels (FlashAttention, Triton, xFormers, etc.).Author: your friendly AI
"""
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------------
# Config
# -------------------------------
@dataclass
class ModelConfig:
    vocab_size: int = 32000
    dim: int = 4096                 # model width (aka d_model)
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: int = 32            # == n_heads => MHA (Llama2-style); < n_heads => GQA (Llama3-style)
    max_seq_len: int = 8192         # context length
    rope_base: float = 10000.0      # RoPE base; change for scaling strategies
    ffn_multiplier: float = 2.6667  # ~8/3, used in LLaMA with (Swi)GLU
    norm_eps: float = 1e-6
    dropout: float = 0.0
    bias: bool = False              # projection bias; Meta's models use bias=False
    tie_embeddings: bool = True

    def __post_init__(self):
        assert self.dim % self.n_heads == 0
        assert self.n_kv_heads <= self.n_heads

# -------------------------------
# RMSNorm
# -------------------------------
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # divide by the root mean square of the last dim: x * rsqrt(mean(x^2) + eps)
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * (x * rms)
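
# --- optional sanity check (illustrative; the helper name is ours, not part of the reference model) ---
# Shows that RMSNorm only rescales each vector by its root mean square and, unlike
# LayerNorm, does not subtract the mean.
def _demo_rmsnorm():
    torch.manual_seed(0)
    x = torch.randn(2, 5, 16)
    y = RMSNorm(16)(x)
    # with weight == 1, every output vector has RMS close to 1
    print("output RMS:", y.pow(2).mean(dim=-1).sqrt().mean().item())
    # the mean is NOT removed (contrast with LayerNorm)
    print("output mean:", y.mean(dim=-1).abs().mean().item())

# -------------------------------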
# Rotary Positional Embeddings (RoPE)
# -------------------------------
class RotaryEmbedding(nn.Module):
    """Precomputes cos/sin tables and applies RoPE to q, k.
    Supports base scaling via rope_base.
    """
    def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0, device=None):
        super().__init__()
        assert dim % 2 == 0, "RoPE dim must be even"
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        t = torch.arange(max_seq_len, device=device).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)  # [T, dim/2]
        self.register_buffer('cos', freqs.cos(), persistent=False)
        self.register_buffer('sin', freqs.sin(), persistent=False)

    def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Apply RoPE to the last dimension of x (head_dim) using the provided positions.
        x: [..., T, dim]
        positions: [T] integer positions (1-D)
        """
        cos = self.cos.index_select(0, positions)  # [T, dim/2]
        sin = self.sin.index_select(0, positions)
        x = x.view(*x.shape[:-1], self.dim // 2, 2)
        x1, x2 = x.unbind(dim=-1)  # each [..., T, dim/2]
        # rotation: (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin)
        rot_x1 = x1 * cos - x2 * sin
        rot_x2 = x2 * cos + x1 * sin
        return torch.stack((rot_x1, rot_x2), dim=-1).flatten(start_dim=-2)
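
# --- optional sanity check (illustrative; the helper name is ours, not part of the reference model) ---
# RoPE's key property: after rotation, the dot product between a query placed at
# position m and a key placed at position n depends only on the offset (m - n).
# The check compares two position pairs with the same offset.
def _demo_rope_relative_positions():
    torch.manual_seed(0)
    dim, max_len = 64, 32
    rope = RotaryEmbedding(dim, max_len)
    q0 = torch.randn(dim)
    k0 = torch.randn(dim)
    pos = torch.arange(max_len)
    q_rot = rope(q0.repeat(max_len, 1), pos)  # the same q0 placed at every position
    k_rot = rope(k0.repeat(max_len, 1), pos)  # the same k0 placed at every position
    # same offset (3) at two different absolute positions -> same attention score
    s1 = torch.dot(q_rot[5], k_rot[2])
    s2 = torch.dot(q_rot[10], k_rot[7])
    print("offset-3 scores:", s1.item(), s2.item())  # should match within float32 tolerance

# -------------------------------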
# KV Cache helper
# -------------------------------
class KVCache:
    """Simple per-layer KV cache.
    Shapes stored as [B, n_kv_heads, T_total, head_dim].
    """
    def __init__(self):
        self.k = None
        self.v = None
        self.len = 0

    def append(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # k, v: [B, n_kv_heads, T_new, head_dim]
        if self.k is None:
            self.k = k
            self.v = v
            self.len = k.size(2)
        else:
            self.k = torch.cat([self.k, k], dim=2)
            self.v = torch.cat([self.v, v], dim=2)
            self.len += k.size(2)
        return self.k, self.v
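
# --- optional sanity check (illustrative; the helper name is ours, not part of the reference model) ---
# The cache contract: appending K/V chunk by chunk yields the same tensors as
# concatenating the whole sequence at once.
def _demo_kv_cache():
    torch.manual_seed(0)
    B, HKV, hd = 1, 4, 16
    k_full = torch.randn(B, HKV, 6, hd)
    v_full = torch.randn(B, HKV, 6, hd)
    cache = KVCache()
    cache.append(k_full[:, :, :4], v_full[:, :, :4])                  # "prompt" chunk
    k_all, v_all = cache.append(k_full[:, :, 4:], v_full[:, :, 4:])   # two decode steps
    print("cache length:", cache.len)                                 # 6
    print("matches full K:", torch.equal(k_all, k_full))              # True

# -------------------------------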
# Attention with GQA & KV cache
# -------------------------------
class Attention(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = cfg.dim // cfg.n_heads
        self.dropout = cfg.dropout
        self.wq = nn.Linear(cfg.dim, cfg.n_heads * self.head_dim, bias=cfg.bias)
        self.wk = nn.Linear(cfg.dim, self.n_kv_heads * self.head_dim, bias=cfg.bias)
        self.wv = nn.Linear(cfg.dim, self.n_kv_heads * self.head_dim, bias=cfg.bias)
        self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.dim, bias=cfg.bias)
        self.rope = RotaryEmbedding(self.head_dim, cfg.max_seq_len, cfg.rope_base)
        # precompute head -> kv-head mapping for GQA: each group of query heads shares one KV head
        idx = torch.arange(self.n_heads)
        self.register_buffer('head_to_kv', torch.div(idx * self.n_kv_heads, self.n_heads, rounding_mode='floor'))

    def forward(self, x: torch.Tensor, cache: Optional[KVCache] = None, start_pos: int = 0) -> torch.Tensor:
        B, T, D = x.shape
        H, HKV, hd = self.n_heads, self.n_kv_heads, self.head_dim
        q = self.wq(x).view(B, T, H, hd).transpose(1, 2)    # [B, H, T, hd]
        k = self.wk(x).view(B, T, HKV, hd).transpose(1, 2)  # [B, HKV, T, hd]
        v = self.wv(x).view(B, T, HKV, hd).transpose(1, 2)  # [B, HKV, T, hd]

        # positions for RoPE: absolute indices in the sequence (cache length .. cache length + T - 1)
        pos = torch.arange(start_pos, start_pos + T, device=x.device)
        q = self.rope(q, pos)  # rotate pairs within the head_dim axis
        k = self.rope(k, pos)

        # append to cache
        if cache is not None:
            k, v = cache.append(k, v)  # [B, HKV, T_total, hd]
            total_T = cache.len
        else:
            total_T = T

        # expand K/V to one slot per query head using the head_to_kv mapping
        # K/V: [B, HKV, T_total, hd] -> [B, H, T_total, hd]
        K = k.index_select(1, self.head_to_kv)
        V = v.index_select(1, self.head_to_kv)

        # causal mask over the cached + new timeline; use PyTorch SDPA
        # (dispatches to FlashAttention / efficient kernels when available)
        attn_mask = torch.triu(torch.full((T, total_T), float('-inf'), device=x.device, dtype=x.dtype),
                               diagonal=1 + start_pos)
        y = F.scaled_dot_product_attention(q, K, V, attn_mask=attn_mask,
                                           dropout_p=self.dropout if self.training else 0.0)
        y = y.transpose(1, 2).contiguous().view(B, T, H * hd)
        y = self.wo(y)
        return y
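
# --- optional sanity check (illustrative; the helper name and config values are ours) ---
# Prints the GQA head grouping and checks that decoding token-by-token with the KV cache
# reproduces the full (cache-free) forward pass.
def _demo_attention_cache_equivalence():
    torch.manual_seed(0)
    cfg = ModelConfig(vocab_size=100, dim=64, n_layers=1, n_heads=8, n_kv_heads=2, max_seq_len=32)
    attn = Attention(cfg).eval()
    print("head_to_kv:", attn.head_to_kv.tolist())   # [0, 0, 0, 0, 1, 1, 1, 1]
    x = torch.randn(1, 6, cfg.dim)
    with torch.no_grad():
        y_full = attn(x)                             # all 6 tokens at once, no cache
        cache = KVCache()
        y_inc = torch.cat([attn(x[:, t:t + 1], cache=cache, start_pos=t) for t in range(6)], dim=1)
    print("max abs diff:", (y_full - y_inc).abs().max().item())  # tiny (float error only)

# -------------------------------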
# SwiGLU MLP
# -------------------------------
class SwiGLU(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        inner = int(cfg.dim * cfg.ffn_multiplier)
        self.w1 = nn.Linear(cfg.dim, inner, bias=cfg.bias)  # main (value) projection
        self.w2 = nn.Linear(cfg.dim, inner, bias=cfg.bias)  # gate projection
        self.w3 = nn.Linear(inner, cfg.dim, bias=cfg.bias)  # output projection
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.w1(x)
        b = self.w2(x)
        gated = a * F.silu(b)  # SiLU/Swish(b) = b * sigmoid(b)
        return self.w3(self.dropout(gated))
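
# --- optional parameter-count note (illustrative; the helper name is ours) ---
# Why ffn_multiplier ~ 8/3: a (Swi)GLU FFN has three weight matrices (two of size
# dim x inner, one of size inner x dim), i.e. 3*dim*inner parameters. Choosing
# inner = (8/3)*dim keeps that roughly equal to a classic 2-matrix MLP with hidden
# size 4*dim (2 * dim * 4*dim = 8*dim^2 parameters).
def _demo_ffn_param_count(dim: int = 4096):
    inner = int(dim * 8 / 3)
    swiglu_params = 3 * dim * inner         # w1, w2, w3 (bias=False)
    classic_params = 2 * dim * (4 * dim)    # up + down projection of a plain GELU MLP
    print(f"SwiGLU FFN: {swiglu_params:,}  classic 4x MLP: {classic_params:,}")

# -------------------------------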
# Transformer Block (Pre-Norm)
# -------------------------------
class TransformerBlock(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.attn = Attention(cfg)
        self.ffn_norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.ffn = SwiGLU(cfg)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor, cache: Optional[KVCache] = None, start_pos: int = 0) -> torch.Tensor:
        h = self.attn(self.attn_norm(x), cache=cache, start_pos=start_pos)
        x = x + self.dropout(h)
        h2 = self.ffn(self.ffn_norm(x))
        x = x + self.dropout(h2)
        return x

# -------------------------------
# LlamaModel
# -------------------------------
class LlamaModel(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_embeddings = nn.Embedding(cfg.vocab_size, cfg.dim)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            self.lm_head.weight = self.tok_embeddings.weight

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 50, temperature: float = 1.0,
                 top_k: Optional[int] = None) -> torch.Tensor:
        """Sampling (optionally top-k) with a KV cache.
        input_ids: [B, T]
        Returns: [B, T + max_new_tokens]
        """
        # one cache per layer; the prompt is processed once, then one token at a time
        caches = [KVCache() for _ in range(self.cfg.n_layers)]
        cur_ids = input_ids
        out = [input_ids]
        start_pos = 0
        for _ in range(max_new_tokens):
            x = self.tok_embeddings(cur_ids)
            for i, blk in enumerate(self.blocks):
                x = blk(x, cache=caches[i], start_pos=start_pos)
            x = self.norm(x)
            logits = self.lm_head(x)[:, -1, :]  # logits of the last position
            if temperature != 1.0:
                logits = logits / max(1e-8, temperature)
            if top_k is not None:
                v, ix = torch.topk(logits, top_k)
                logits = torch.full_like(logits, float('-inf')).scatter_(1, ix, v)
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # [B, 1]
            out.append(next_id)
            # advance by however many tokens were just fed (the full prompt on the first step, 1 afterwards)
            start_pos += cur_ids.size(1)
            cur_ids = next_id
        return torch.cat(out, dim=1)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Training forward without a KV cache (teacher forcing)."""
        x = self.tok_embeddings(input_ids)
        for blk in self.blocks:
            x = blk(x, cache=None, start_pos=0)
        x = self.norm(x)
        return self.lm_head(x)
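
# --- optional training-step sketch (illustrative; the helper name is ours, not part of the reference model) ---
# forward() returns logits for every position; next-token training shifts the targets
# by one, as sketched below.
def _demo_training_step(model: 'LlamaModel', input_ids: torch.Tensor) -> torch.Tensor:
    logits = model(input_ids)                       # [B, T, vocab]
    targets = input_ids[:, 1:]                      # predict token t+1 from the prefix up to t
    logits = logits[:, :-1, :]
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    return loss

# -------------------------------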
# Example configs: Llama2-like vs Llama3-like
# -------------------------------
LLAMA2_7B = ModelConfig(
    vocab_size=32000,
    dim=4096,
    n_layers=32,
    n_heads=32,
    n_kv_heads=32,        # MHA (no grouping)
    max_seq_len=4096,
    rope_base=10000.0,
    ffn_multiplier=2.6667,
    norm_eps=1e-6,
    dropout=0.0,
    bias=False,
)

LLAMA3_8B_LIKE = ModelConfig(
    vocab_size=128000,    # Llama3 uses a larger tokenizer; illustrative value
    dim=4096,
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,         # GQA: 4 query heads share 1 KV head
    max_seq_len=8192,
    rope_base=10000.0,    # scaling strategies can modify this; the released Llama 3 uses a larger base (500000)
    ffn_multiplier=2.6667,
    norm_eps=1e-6,
    dropout=0.0,
    bias=False,
)
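
# --- optional parameter-count estimate (illustrative; the helper name is ours) ---
# A rough estimate of the parameter counts implied by the configs above (tied embeddings,
# bias=False), useful for sizing memory before instantiating anything. Note that the
# official checkpoints differ (e.g., wider FFN, untied embeddings), so real counts are higher.
def _estimate_params(cfg: ModelConfig) -> int:
    hd = cfg.dim // cfg.n_heads
    inner = int(cfg.dim * cfg.ffn_multiplier)
    attn = 2 * cfg.dim * cfg.dim + 2 * cfg.dim * cfg.n_kv_heads * hd  # wq, wo + wk, wv
    ffn = 3 * cfg.dim * inner                                         # w1, w2, w3
    norms = 2 * cfg.dim                                               # two RMSNorms per block
    embed = cfg.vocab_size * cfg.dim                                  # tied with lm_head
    return cfg.n_layers * (attn + ffn + norms) + embed + cfg.dim      # + final norm

# -------------------------------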
# Quick usage (pseudo)
# -------------------------------
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Pick a config. Note: the full-size configs build multi-billion-parameter models;
    # shrink dim / n_layers for a quick smoke test on modest hardware.
    cfg = LLAMA3_8B_LIKE
    model = LlamaModel(cfg).to(device)
    model.eval()
    # dummy ids
    x = torch.randint(0, cfg.vocab_size, (1, 10), device=device)
    with torch.no_grad():
        y = model.generate(x, max_new_tokens=5, temperature=1.0, top_k=50)
    print("Generated shape:", y.shape)