[Chapter 4: Large Language Models (LLM)] 09. The Strongest Open-Source LLM: Llama 3, Principles and Implementation - (6) Llama 2 & Llama 3 Code Implementation

Chapter 4: Large Language Models (LLM)

Part 9: The Strongest Open-Source LLM: Llama 3, Principles and Implementation

Section 6: Llama 2 & Llama 3 Code Implementation


| Feature | Llama 2 | Llama 3 | Notes |
|---|---|---|---|
| Release date | July 2023 | April 2024 | Llama 2 was Meta's first openly released LLM with a commercial-use license; Llama 3 is a major upgrade |
| Parameter sizes | 7B / 13B / 70B | 8B / 70B (larger models to follow) | Both generations offer a small and a large variant |
| Pre-training data | ~2T tokens | ~15T tokens | Nearly an 8x increase in data volume |
| Tokenizer vocabulary | 32k (SentencePiece BPE) | 128k (tiktoken-based BPE) | Broader language and symbol coverage, higher encoding efficiency |
| Normalization | RMSNorm | RMSNorm | Both replace LayerNorm with RMSNorm, which is cheaper to compute |
| Positional encoding | RoPE (rotary position embedding) | RoPE + longer context | Llama 3 supports an 8k context and can be extended further (e.g. 16k) with RoPE scaling |
| Attention | Multi-Head Attention (MHA); GQA only in the 70B model | Grouped-Query Attention (GQA) at all sizes | GQA shares KV heads across groups of query heads, shrinking the KV cache and speeding up inference (see the sketch after this table) |
| Activation | SwiGLU | SwiGLU | Unchanged across generations; replaces the conventional GELU |
| KV cache optimization | Basic caching | KV cache + efficient parallelization | Llama 3 is tuned for long-context inference |
| Training optimization | ZeRO + FSDP + FlashAttention | More mature distributed training (ZeRO Stage 3, advanced scheduling, efficient communication) | Llama 3 puts more weight on training/inference throughput |
| Inference performance | Good | Significantly better (especially the larger model) | Thanks to GQA and the more efficient KV cache |
| Typical applications | English, limited multilingual | English, multilingual, stronger coding | Llama 3 handles multilingual and code-generation tasks better |
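
To make the GQA and KV-cache rows concrete, here is a rough back-of-the-envelope sketch of KV-cache memory. It assumes fp16 cache entries, head_dim = 128, 32 layers, batch size 1, and an 8k context, mirroring the illustrative configs in the code below; the numbers are estimates, not official figures.

def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int = 1, bytes_per_elem: int = 2) -> int:
    # each layer stores K and V, both of shape [batch, n_kv_heads, seq_len, head_dim]
    return 2 * n_layers * batch * n_kv_heads * seq_len * head_dim * bytes_per_elem

mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=8192)  # Llama2-style MHA
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8,  head_dim=128, seq_len=8192)  # Llama3-style GQA
print(f"MHA: {mha / 2**30:.1f} GiB, GQA: {gqa / 2**30:.1f} GiB ({mha / gqa:.0f}x smaller)")

With 32 query heads sharing 8 KV heads, the cache shrinks by a factor of 4 (here roughly 4 GiB down to 1 GiB per sequence), which is the main reason GQA speeds up long-context decoding.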

Code

"""
Llama2 & Llama3 minimal reference implementation in PyTorch
-----------------------------------------------------------------
This single-file implementation aims to be:
- Educational: clear, commented, and faithful to core design choices.
- Practical: supports KV cache, RoPE, RMSNorm, SwiGLU, and GQA (Grouped Multi-Query Attention).
- Flexible: one Config drives either a Llama2-like (MHA) or Llama3-like (GQA) model.DISCLAIMER
- Dimensions and defaults below are illustrative; do not expect parity with official checkpoints.
- For real training/inference use highly optimized kernels (FlashAttention, Triton, xFormers, etc.).Author: your friendly AI
"""
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# -------------------------------
# Config
# -------------------------------
@dataclass
class ModelConfig:
    vocab_size: int = 32000
    dim: int = 4096                 # model width (aka d_model)
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: int = 32            # == n_heads => MHA (Llama2-style); < n_heads => GQA (Llama3-style)
    max_seq_len: int = 8192         # context length
    rope_base: float = 10000.0      # RoPE base; change for scaling strategies
    ffn_multiplier: float = 2.6667  # ~8/3, as used in LLaMA with (Swi)GLU
    norm_eps: float = 1e-6
    dropout: float = 0.0
    bias: bool = False              # projection bias; Meta models use bias=False
    tie_embeddings: bool = True

    def __post_init__(self):
        assert self.dim % self.n_heads == 0
        assert self.n_kv_heads <= self.n_heads


# -------------------------------
# RMSNorm
# -------------------------------
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # root mean square along last dim
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * (x * rms)


# -------------------------------
# Rotary Positional Embeddings (RoPE)
# -------------------------------
class RotaryEmbedding(nn.Module):
    """Precomputes cos/sin tables and applies RoPE to q/k.

    Supports base scaling via rope_base.
    """

    def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0, device=None):
        super().__init__()
        assert dim % 2 == 0, "RoPE dim must be even"
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        t = torch.arange(max_seq_len, device=device).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)  # [T, dim/2]
        self.register_buffer('cos', freqs.cos(), persistent=False)
        self.register_buffer('sin', freqs.sin(), persistent=False)

    def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Apply RoPE to the last dimension of x (head_dim) using the given positions.

        x: [..., T, dim]
        positions: [T] integer positions
        """
        cos = self.cos.index_select(0, positions)  # [T, dim/2]
        sin = self.sin.index_select(0, positions)
        x = x.view(*x.shape[:-1], self.dim // 2, 2)
        x1, x2 = x.unbind(dim=-1)
        # rotation: (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin)
        rot_x1 = x1 * cos - x2 * sin
        rot_x2 = x2 * cos + x1 * sin
        return torch.stack((rot_x1, rot_x2), dim=-1).flatten(start_dim=-2)


# -------------------------------
# KV Cache helper
# -------------------------------
class KVCache:
    """Simple KV cache per layer.

    Shapes stored as [B, n_kv_heads, T_total, head_dim].
    """

    def __init__(self):
        self.k = None
        self.v = None
        self.len = 0

    def append(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # k, v: [B, n_kv_heads, T_new, head_dim]
        if self.k is None:
            self.k = k
            self.v = v
            self.len = k.size(2)
        else:
            self.k = torch.cat([self.k, k], dim=2)
            self.v = torch.cat([self.v, v], dim=2)
            self.len += k.size(2)
        return self.k, self.v


# -------------------------------
# Attention with GQA & KV cache
# -------------------------------
class Attention(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = cfg.dim // cfg.n_heads
        self.dropout = cfg.dropout
        self.wq = nn.Linear(cfg.dim, cfg.n_heads * self.head_dim, bias=cfg.bias)
        self.wk = nn.Linear(cfg.dim, self.n_kv_heads * self.head_dim, bias=cfg.bias)
        self.wv = nn.Linear(cfg.dim, self.n_kv_heads * self.head_dim, bias=cfg.bias)
        self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.dim, bias=cfg.bias)
        self.rope = RotaryEmbedding(self.head_dim, cfg.max_seq_len, cfg.rope_base)
        # GQA: each group of n_heads // n_kv_heads query heads shares one KV head
        assert self.n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_rep = self.n_heads // self.n_kv_heads

    def forward(self, x: torch.Tensor, cache: Optional[KVCache] = None, start_pos: int = 0) -> torch.Tensor:
        B, T, D = x.shape
        H, HKV, hd = self.n_heads, self.n_kv_heads, self.head_dim

        q = self.wq(x).view(B, T, H, hd).transpose(1, 2)    # [B, H, T, hd]
        k = self.wk(x).view(B, T, HKV, hd).transpose(1, 2)  # [B, HKV, T, hd]
        v = self.wv(x).view(B, T, HKV, hd).transpose(1, 2)  # [B, HKV, T, hd]

        # positions for RoPE: absolute indices in the sequence (cache length .. cache length + T - 1)
        pos = torch.arange(start_pos, start_pos + T, device=x.device)
        q = self.rope(q, pos)  # rotate the head_dim (last) dimension
        k = self.rope(k, pos)

        # append to cache
        if cache is not None:
            k, v = cache.append(k, v)  # [B, HKV, T_total, hd]
            total_T = cache.len
        else:
            total_T = T

        # expand K/V so every query head has a matching KV head
        # [B, HKV, T_total, hd] -> [B, H, T_total, hd]
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        # causal mask, offset by the cache length: query i may not see key j > start_pos + i
        attn_mask = torch.triu(
            torch.full((T, total_T), float('-inf'), device=x.device),
            diagonal=1 + start_pos,
        )
        # use PyTorch SDPA (dispatches to FlashAttention when available)
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, H * hd)
        return self.wo(y)


# -------------------------------
# SwiGLU MLP
# -------------------------------
class SwiGLU(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        inner = int(cfg.dim * cfg.ffn_multiplier)
        self.w1 = nn.Linear(cfg.dim, inner, bias=cfg.bias)  # main
        self.w2 = nn.Linear(cfg.dim, inner, bias=cfg.bias)  # gate
        self.w3 = nn.Linear(inner, cfg.dim, bias=cfg.bias)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.w1(x)
        b = self.w2(x)
        gated = a * F.silu(b)  # Swish(b) = b * sigmoid(b)
        return self.w3(self.dropout(gated))


# -------------------------------
# Transformer Block (Pre-Norm)
# -------------------------------
class TransformerBlock(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.attn = Attention(cfg)
        self.ffn_norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.ffn = SwiGLU(cfg)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor, cache: Optional[KVCache] = None, start_pos: int = 0) -> torch.Tensor:
        h = self.attn(self.attn_norm(x), cache=cache, start_pos=start_pos)
        x = x + self.dropout(h)
        h2 = self.ffn(self.ffn_norm(x))
        x = x + self.dropout(h2)
        return x


# -------------------------------
# LlamaModel
# -------------------------------
class LlamaModel(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_embeddings = nn.Embedding(cfg.vocab_size, cfg.dim)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = RMSNorm(cfg.dim, cfg.norm_eps)
        self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            self.lm_head.weight = self.tok_embeddings.weight

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 50, temperature: float = 1.0,
                 top_k: Optional[int] = None) -> torch.Tensor:
        """Sampling (optionally top-k) with a KV cache.

        input_ids: [B, T]
        Returns: [B, T + max_new_tokens]
        """
        # one cache per layer; the prompt is processed in full on the first step,
        # then one token at a time
        caches = [KVCache() for _ in range(self.cfg.n_layers)]
        cur_ids = input_ids
        out = [input_ids]
        start_pos = 0
        for _ in range(max_new_tokens):
            x = self.tok_embeddings(cur_ids)
            for i, blk in enumerate(self.blocks):
                x = blk(x, cache=caches[i], start_pos=start_pos)
            x = self.norm(x)
            logits = self.lm_head(x)[:, -1, :]  # last-position logits
            if temperature != 1.0:
                logits = logits / max(1e-8, temperature)
            if top_k is not None:
                v, ix = torch.topk(logits, top_k)
                logits = torch.full_like(logits, float('-inf')).scatter_(1, ix, v)
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # [B, 1]
            out.append(next_id)
            start_pos += cur_ids.size(1)  # advance by however many tokens were just cached
            cur_ids = next_id
        return torch.cat(out, dim=1)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Training forward pass (teacher forcing), without an external cache."""
        x = self.tok_embeddings(input_ids)
        for blk in self.blocks:
            x = blk(x, cache=None, start_pos=0)
        x = self.norm(x)
        return self.lm_head(x)


# -------------------------------
# Example configs: Llama2-like vs Llama3-like
# -------------------------------
LLAMA2_7B = ModelConfig(
    vocab_size=32000,
    dim=4096,
    n_layers=32,
    n_heads=32,
    n_kv_heads=32,          # MHA (no grouping)
    max_seq_len=4096,
    rope_base=10000.0,
    ffn_multiplier=2.6667,
    norm_eps=1e-6,
    dropout=0.0,
    bias=False,
)

LLAMA3_8B_LIKE = ModelConfig(
    vocab_size=128000,      # Llama3 uses a larger tokenizer; illustrative value
    dim=4096,
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,           # GQA: 4 Q-heads share 1 KV head
    max_seq_len=8192,
    rope_base=10000.0,      # scaling strategies can modify this
    ffn_multiplier=2.6667,
    norm_eps=1e-6,
    dropout=0.0,
    bias=False,
)


# -------------------------------
# Quick usage (pseudo)
# -------------------------------
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Pick a config.
    # NOTE: a full-size config allocates billions of randomly initialized parameters;
    # shrink dim / n_layers for a quick local smoke test.
    cfg = LLAMA3_8B_LIKE
    model = LlamaModel(cfg).to(device)
    model.eval()

    # dummy ids
    x = torch.randint(0, cfg.vocab_size, (1, 10), device=device)
    with torch.no_grad():
        y = model.generate(x, max_new_tokens=5, temperature=1.0, top_k=50)
    print("Generated shape:", y.shape)
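
For a quick local check of both the teacher-forcing forward pass and KV-cached generation, a deliberately tiny configuration is handy. The sketch below is illustrative only: the sizes are arbitrary (not an official model), the name TINY_GQA is made up, and it assumes it runs in the same file as the classes above.

# a toy config, small enough to instantiate on CPU in a few seconds
TINY_GQA = ModelConfig(
    vocab_size=1000,
    dim=128,
    n_layers=2,
    n_heads=8,
    n_kv_heads=2,      # GQA: 4 query heads share each KV head
    max_seq_len=256,
)

tiny = LlamaModel(TINY_GQA)
ids = torch.randint(0, TINY_GQA.vocab_size, (2, 16))
logits = tiny(ids)                              # teacher forcing: [2, 16, 1000]
sample = tiny.generate(ids, max_new_tokens=8)   # KV-cached decoding: [2, 24]
print(logits.shape, sample.shape)

Setting n_kv_heads equal to n_heads in the same config turns the block back into a Llama2-style MHA model, which is the whole point of driving both variants from one ModelConfig.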

