Mini DeepSeek-v3训练脚本学习
Mini DeepSeek-v3 训练脚本详细技术说明(脚本在文章最后)
📋 概述
这是一个实现了Mini DeepSeek-v3大语言模型的训练脚本,集成了多项先进的深度学习技术。该脚本支持自动GPU选择和分布式训练,适合在多GPU环境下训练Transformer模型。
🚀 快速开始
运行方式
# 方式1:自动选择GPU并启动分布式训练
python train_mini_deepseek.py# 方式2:手动指定GPU
CUDA_VISIBLE_DEVICES=1,4 torchrun --standalone --nproc_per_node=2 train_mini_deepseek.py
🏗️ 架构解析
1. GPU自动选择机制 (pick_top_gpus
)
作用:自动选择显存最大的GPU进行训练
原理:
- 使用NVML库查询所有GPU的显存状态
- 按空闲显存大小排序,选择前N个GPU
- 如果NVML不可用,默认选择前N个GPU
通俗解释:就像在停车场找最空的停车位,程序会自动找到显存最充足的GPU来训练模型。
2. 模型配置 (CFG类
)
class CFG:vocab = 32_000 # 词汇表大小:模型能理解多少个不同的词max_seq = 1_024 # 最大序列长度:一次能处理多长的文本d_model = 1_024 # 模型维度:每个词用多少个数字来表示n_layer = 6 # 层数:模型有多少层神经网络n_head = 16 # 注意力头数:同时关注多少个方面latent_k = 64 # 潜在空间维度mlp_mult = 4 # MLP倍数moe_expert = 2 # 专家数量
通俗解释:这就像定义一个大脑的结构参数 - 能记住多少词汇、能同时思考多长的句子、大脑有多少层等等。
🧠 核心算法详解
3. RMSNorm 归一化
传统LayerNorm问题:计算复杂,需要计算均值和方差
RMSNorm优势:
- 只计算RMS (Root Mean Square),更简单高效
- 公式:
x * rsqrt(mean(x²) + ε) * weight
通俗解释:想象你在调音响的音量,RMSNorm就是一个自动音量控制器,确保每层神经网络的"音量"都保持在合适的范围内,这样信息传递更稳定。
def forward(self, x):var = x.pow(2).mean(-1, keepdim=True) # 计算平方的平均值x = x * torch.rsqrt(var + self.eps) # 归一化return self.weight * x # 加权输出
4. RoPE 旋转位置编码
问题:Transformer如何知道词语在句子中的位置?
RoPE解决方案:
- 将位置信息编码为旋转角度
- 通过复数旋转在高维空间中表示位置
- 具有良好的外推性能
通俗解释:就像给每个词戴上一个特殊的"位置手环",手环会根据词的位置旋转不同的角度,这样模型就能知道每个词在句子中的确切位置。
def rope(x, pos):d = x.size(-1); half = d // 2freq = torch.arange(half, device=x.device) / halftheta = pos[:, None] / (10000 ** freq) # 计算旋转角度cos, sin = theta.cos(), theta.sin()xe, xo = x[..., 0::2], x[..., 1::2] # 分离奇偶维度# 应用旋转变换return torch.cat([xe * cos - xo * sin, xe * sin + xo * cos], -1)
5. MHLA (Multi-Head Latent Attention)
创新点:引入潜在空间的注意力机制
三个阶段:
- Read阶段:从输入序列中读取信息到潜在空间
- Latent Self阶段:在潜在空间内进行自注意力
- Write阶段:将潜在空间的信息写回到输出序列
通俗解释:想象你在开会做笔记:
- Read:把别人说的话记录到你的笔记本上
- Latent Self:在脑海中整理和思考这些信息
- Write:基于思考结果给出你的回应
def forward(self, x, pos):# Read: 从输入读取到潜在空间z1 = self._attn(self.qL(z0), self.kX(x), self.vX(x))# Latent Self: 潜在空间内自注意力z2 = self._attn(self.qS(z1), self.kS(z1), self.vS(z1))# Write: 从潜在空间写回输出y = self._attn(self.qX2(x), self.kL2(z2), self.vL2(z2))
6. MoE (Mixture of Experts)
核心思想:不是所有神经元都参与每次计算
工作原理:
- Gate网络:决定激活哪个专家
- 专家网络:每个专家负责处理特定类型的输入
- 路由机制:根据输入特征选择最合适的专家
通俗解释:就像一个医院,不同的病人会被分配给不同专科的医生。Gate网络是分诊台,专家网络是各科医生,每个"病人"(输入数据)会被送到最合适的"医生"(专家)那里处理。
def forward(self, x):route = self.gate(x).softmax(-1) # 计算路由概率idx = route.argmax(-1) # 选择最佳专家out = torch.zeros_like(x)for i, exp in enumerate(self.experts):m = idx == i # 找到分配给专家i的数据if m.any(): out[m] = exp(x[m]) # 专家处理对应数据return out
7. SwiGLU 激活函数
组合设计:Swish + GLU (Gated Linear Unit)
公式:SwiGLU(x) = Swish(W1(x)) ⊗ W2(x)
优势:
- 结合了Swish的平滑特性
- 加入了门控机制增强表达能力
通俗解释:像一个智能开关,不仅能控制信号的强弱(Swish部分),还能决定哪些信号可以通过(门控部分)。
def forward(self, x): return self.w3(torch.nn.functional.silu(self.w1(x)) * self.w2(x))# ↑ Swish激活 ↑ 门控机制 ↑ 输出变换
🔧 训练流程详解
8. 分布式训练设置
初始化过程:
- 获取本地GPU编号 (
LOCAL_RANK
) - 设置当前设备
- 初始化进程组 (
nccl
后端) - 将模型包装为DDP
通俗解释:就像组织一个团队项目,每个GPU就是一个团队成员,需要先分配任务、建立通信机制,然后协同工作。
9. 学习率调度
Warmup + 线性递减策略:
def lr(it):if it < CFG.warmup: return CFG.lr * it / CFG.warmup # 预热阶段:逐渐增加return CFG.lr * (1 - it / CFG.total_step) # 训练阶段:线性递减
通俗解释:就像开车一样,刚开始要慢慢加速(warmup),然后在行程接近结束时逐渐减速,这样能让模型训练更稳定。
10. 训练循环
核心步骤:
- 前向传播:输入数据,计算预测结果
- 计算损失:比较预测和真实结果
- 反向传播:计算梯度
- 梯度裁剪:防止梯度爆炸
- 参数更新:优化模型参数
for x, y in dl:logits = model(x)[:, :-1].reshape(-1, CFG.vocab) # 前向传播loss = nn.functional.cross_entropy(logits, y[:, 1:].reshape(-1)) # 计算损失loss.backward() # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # 梯度裁剪opt.step() # 参数更新
📊 性能优化要点
内存优化
- 梯度检查点:牺牲计算时间换取内存空间
- 混合精度训练:使用FP16减少内存占用
- 优化器状态管理:
set_to_none=True
释放内存
计算优化
- Fused AdamW:融合优化器操作减少kernel启动
- Pin Memory:加速CPU到GPU的数据传输
- 非阻塞传输:
non_blocking=True
并行化数据传输
🎯 实际应用建议
硬件要求
- 最低配置:2张RTX 3090 (24GB显存)
- 推荐配置:2张RTX 4090 (24GB显存)
- 内存:至少32GB系统内存
调参建议
- 学习率:根据batch size调整,遵循线性缩放规则
- Warmup步数:通常设为总训练步数的5-10%
- 批次大小:根据显存容量调整,保证梯度稳定
常见问题
- OOM (Out of Memory):减少batch_size或max_seq_len
- 训练不稳定:检查学习率设置和梯度裁剪阈值
- 收敛缓慢:调整warmup策略和学习率调度
🔍 代码扩展建议
功能增强
- 添加验证集评估
- 实现模型检查点保存/加载
- 集成Wandb等实验跟踪工具
- 支持更多数据格式
性能提升
- 实现动态batch size
- 添加梯度累积功能
- 支持更多优化器选择
- 集成Flash Attention
这个脚本展示了现代大语言模型训练中的多项前沿技术,是学习和研究Transformer架构的优秀参考实现。
mini_ds.py
# train_mini_deepseek.py ②GPU ⾃动挑卡版
"""
Mini DeepSeek-v3, auto-pick 2 GPUs with max free memory.Run simply:python train_mini_deepseek.py # auto pick & spawn
or:CUDA_VISIBLE_DEVICES=1,4 torchrun --standalone --nproc_per_node=2 train_mini_deepseek.py
"""import os
import random
import subprocess
import sysimport math
from torch.distributed.elastic.multiprocessing.errors import record# ---------- 0. GPU AUTO-PICK & SELF-SPAWN ----------
def pick_top_gpus(num=2):"""return gpu indices with largest free memory"""try:import pynvml, torchpynvml.nvmlInit()infos = []for i in range(torch.cuda.device_count()):h = pynvml.nvmlDeviceGetHandleByIndex(i)free_mem = pynvml.nvmlDeviceGetMemoryInfo(h).freeinfos.append((free_mem, i))pynvml.nvmlShutdown()infos.sort(reverse=True) # by free memoryreturn [idx for _, idx in infos[:num]]except Exception:# NVML 失效或 CUDA 不可用 → 默认选前 num 个return list(range(num))if "LOCAL_RANK" not in os.environ:gpu_ids = pick_top_gpus(2)print(f'gpu_ids: {gpu_ids}')os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))print(f"[AutoPick] Use GPUs: {os.environ['CUDA_VISIBLE_DEVICES']}")# 重新启动自身为 2 个分布式进程cmd = ["torchrun", "--standalone", f"--nproc_per_node={len(gpu_ids)}", sys.argv[0], *sys.argv[1:]]subprocess.check_call(cmd)sys.exit(0)# ---------- 1. 之后才 import torch 及重型包 ----------
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DistributedDataParallel as DDPif not hasattr(nn, "RMSNorm"):class RMSNorm(nn.Module):def __init__(self, d, eps=1e-5):super().__init__()self.weight = nn.Parameter(torch.ones(d))self.eps = epsdef forward(self, x):# x: (B, N, d)var = x.pow(2).mean(-1, keepdim=True)x = x * torch.rsqrt(var + self.eps)return self.weight * xnn.RMSNorm = RMSNorm # 注册到 torch.nn 里# ---------- 2. 配置 ----------
class CFG:vocab = 32_000;max_seq = 1_024d_model = 1_024;n_layer = 6;n_head = 16latent_k = 64mlp_mult = 4;moe_expert = 2lr = 3e-4;warmup = 100;total_step = 1_000batch = 4;seed = 42torch.manual_seed(CFG.seed);
random.seed(CFG.seed)# ---------- 3. 数据 ----------
class RandomDataset(Dataset):def __len__(self): return 10_000_000def __getitem__(self, idx):x = torch.randint(0, CFG.vocab, (CFG.max_seq,))y = torch.roll(x, -1)return x, y# ---------- 4. RoPE ----------
def rope(x, pos):d = x.size(-1);half = d // 2freq = torch.arange(half, device=x.device) / halftheta = pos[:, None] / (10000 ** freq)cos, sin = theta.cos(), theta.sin()xe, xo = x[..., 0::2], x[..., 1::2]return torch.cat([xe * cos - xo * sin, xe * sin + xo * cos], -1)# ---------- 5. MHLA ----------
class MHLA(nn.Module):def __init__(self):super().__init__()d, h, k = CFG.d_model, CFG.n_head, CFG.latent_kself.h, self.k, self.d = h, k, ddef lin(): return nn.Linear(d, d, bias=False)self.qL, self.kX, self.vX = lin(), lin(), lin()self.qS, self.kS, self.vS = lin(), lin(), lin()self.qX2, self.kL2, self.vL2 = lin(), lin(), lin()self.out = lin()self.latent = nn.Parameter(torch.randn(1, k, d) / math.sqrt(d))self.n1 = nn.RMSNorm(d);self.n2 = nn.RMSNorm(d);self.n3 = nn.RMSNorm(d)def _split(self, x):B, N, _ = x.shapereturn x.view(B, N, self.h, -1).permute(0, 2, 1, 3).reshape(B * self.h, N, -1)def _merge(self, x, B, N): # inversereturn x.view(B, self.h, N, -1).permute(0, 2, 1, 3).reshape(B, N, self.d)def _attn(self, q, k, v):s = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))return (s.softmax(-1) @ v)def forward(self, x, pos):B, N, _ = x.shapez0 = self.latent.expand(B, -1, -1)# readz1 = self._attn(self._split(self.qL(z0)),self._split(self.kX(x)), self._split(self.vX(x)))z1 = self._merge(z1, B, self.k);z1 = self.n1(z0 + self.out(z1))# latent selfz2 = self._attn(self._split(self.qS(z1)),self._split(self.kS(z1)), self._split(self.vS(z1)))z2 = self._merge(z2, B, self.k);z2 = self.n2(z1 + self.out(z2))# writey = self._attn(self._split(self.qX2(x)),self._split(self.kL2(z2)), self._split(self.vL2(z2)))y = self._merge(y, B, N);y = self.n3(x + self.out(y))return y# ---------- 6. MoE FeedForward ----------
class SwiGLU(nn.Module):def __init__(self, d_in, d_hidden):super().__init__()self.w1 = nn.Linear(d_in, d_hidden, False)self.w2 = nn.Linear(d_in, d_hidden, False)self.w3 = nn.Linear(d_hidden, d_in, False)def forward(self, x): return self.w3(torch.nn.functional.silu(self.w1(x)) * self.w2(x))class MoE(nn.Module):def __init__(self):super().__init__()d, h = CFG.d_model, CFG.mlp_mult * CFG.d_modelself.experts = nn.ModuleList([SwiGLU(d, h) for _ in range(CFG.moe_expert)])self.gate = nn.Linear(d, CFG.moe_expert, False)def forward(self, x):route = self.gate(x).softmax(-1)idx = route.argmax(-1)out = torch.zeros_like(x)for i, exp in enumerate(self.experts):m = idx == iif m.any(): out[m] = exp(x[m])return out# ---------- 7. Transformer Block ----------
class Block(nn.Module):def __init__(self, i):super().__init__()self.attn = MHLA() if i % 2 else nn.MultiheadAttention(CFG.d_model, CFG.n_head, batch_first=True)self.norm = nn.RMSNorm(CFG.d_model)self.ffn = MoE()# Pre-compute causal mask for MultiheadAttentionif not isinstance(self.attn, MHLA):self.register_buffer('causal_mask',torch.triu(torch.ones(CFG.max_seq, CFG.max_seq) * float('-inf'), diagonal=1))def forward(self, x, pos):if isinstance(self.attn, nn.MultiheadAttention):q = k = v = rope(x, pos)seq_len = x.size(1)# Use the pre-computed causal mask, truncated to current sequence lengthmask = self.causal_mask[:seq_len, :seq_len]a, _ = self.attn(q, k, v, need_weights=False, attn_mask=mask)else:a = self.attn(rope(x, pos), pos)x = x + ax = x + self.ffn(self.norm(x))return x# ---------- 8. Model ----------
class MiniDeepSeek(nn.Module):def __init__(self):super().__init__()self.embed = nn.Embedding(CFG.vocab, CFG.d_model)self.blocks = nn.ModuleList([Block(i) for i in range(CFG.n_layer)])self.ln_f = nn.RMSNorm(CFG.d_model)self.head = nn.Linear(CFG.d_model, CFG.vocab, False)def forward(self, idx):pos = torch.arange(idx.size(1), device=idx.device)h = self.embed(idx)for blk in self.blocks: h = blk(h, pos)return self.head(self.ln_f(h))# ---------- 9. Train ----------
@record
def main():local_rank = int(os.environ["LOCAL_RANK"])torch.cuda.set_device(local_rank)torch.distributed.init_process_group("nccl")model = MiniDeepSeek().cuda()model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)opt = torch.optim.AdamW(model.parameters(), lr=CFG.lr, fused=True)ds = RandomDataset()sampler = torch.utils.data.DistributedSampler(ds)dl = DataLoader(ds, batch_size=CFG.batch, sampler=sampler,pin_memory=True, num_workers=2)def lr(it):if it < CFG.warmup: return CFG.lr * it / CFG.warmupreturn CFG.lr * (1 - it / CFG.total_step)model.train()step = 0for x, y in dl:step += 1sampler.set_epoch(step)x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)logits = model(x)[:, :-1].reshape(-1, CFG.vocab)loss = nn.functional.cross_entropy(logits, y[:, 1:].reshape(-1))loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)for g in opt.param_groups: g["lr"] = lr(step)opt.step()opt.zero_grad(set_to_none=True)if step % 50 == 0 and torch.distributed.get_rank() == 0:print(f"step {step}/{CFG.total_step} loss {loss.item():.4f}")if step >= CFG.total_step: breaktorch.distributed.destroy_process_group()if __name__ == "__main__":main()
# CUDA_VISIBLE_DEVICES=5,7 python -m torch.distributed.run --nproc_per_node=2 mini_ds.py
运行结果
执行脚本CUDA_VISIBLE_DEVICES=5,7 python -m torch.distributed.run --nproc_per_node=2 mini_ds.py