当前位置: 首页 > news >正文

Mini DeepSeek-v3训练脚本学习

Mini DeepSeek-v3 训练脚本详细技术说明(脚本在文章最后)

📋 概述

这是一个实现了Mini DeepSeek-v3大语言模型的训练脚本,集成了多项先进的深度学习技术。该脚本支持自动GPU选择和分布式训练,适合在多GPU环境下训练Transformer模型。

🚀 快速开始

运行方式

# 方式1:自动选择GPU并启动分布式训练
python train_mini_deepseek.py# 方式2:手动指定GPU
CUDA_VISIBLE_DEVICES=1,4 torchrun --standalone --nproc_per_node=2 train_mini_deepseek.py

🏗️ 架构解析

1. GPU自动选择机制 (pick_top_gpus)

作用:自动选择显存最大的GPU进行训练

原理

  • 使用NVML库查询所有GPU的显存状态
  • 按空闲显存大小排序,选择前N个GPU
  • 如果NVML不可用,默认选择前N个GPU

通俗解释:就像在停车场找最空的停车位,程序会自动找到显存最充足的GPU来训练模型。

2. 模型配置 (CFG类)

class CFG:vocab = 32_000      # 词汇表大小:模型能理解多少个不同的词max_seq = 1_024     # 最大序列长度:一次能处理多长的文本d_model = 1_024     # 模型维度:每个词用多少个数字来表示n_layer = 6         # 层数:模型有多少层神经网络n_head = 16         # 注意力头数:同时关注多少个方面latent_k = 64       # 潜在空间维度mlp_mult = 4        # MLP倍数moe_expert = 2      # 专家数量

通俗解释:这就像定义一个大脑的结构参数 - 能记住多少词汇、能同时思考多长的句子、大脑有多少层等等。

🧠 核心算法详解

3. RMSNorm 归一化

传统LayerNorm问题:计算复杂,需要计算均值和方差

RMSNorm优势

  • 只计算RMS (Root Mean Square),更简单高效
  • 公式:x * rsqrt(mean(x²) + ε) * weight

通俗解释:想象你在调音响的音量,RMSNorm就是一个自动音量控制器,确保每层神经网络的"音量"都保持在合适的范围内,这样信息传递更稳定。

def forward(self, x):var = x.pow(2).mean(-1, keepdim=True)  # 计算平方的平均值x = x * torch.rsqrt(var + self.eps)    # 归一化return self.weight * x                  # 加权输出

4. RoPE 旋转位置编码

问题:Transformer如何知道词语在句子中的位置?

RoPE解决方案

  • 将位置信息编码为旋转角度
  • 通过复数旋转在高维空间中表示位置
  • 具有良好的外推性能

通俗解释:就像给每个词戴上一个特殊的"位置手环",手环会根据词的位置旋转不同的角度,这样模型就能知道每个词在句子中的确切位置。

def rope(x, pos):d = x.size(-1); half = d // 2freq = torch.arange(half, device=x.device) / halftheta = pos[:, None] / (10000 ** freq)  # 计算旋转角度cos, sin = theta.cos(), theta.sin()xe, xo = x[..., 0::2], x[..., 1::2]     # 分离奇偶维度# 应用旋转变换return torch.cat([xe * cos - xo * sin, xe * sin + xo * cos], -1)

5. MHLA (Multi-Head Latent Attention)

创新点:引入潜在空间的注意力机制

三个阶段

  1. Read阶段:从输入序列中读取信息到潜在空间
  2. Latent Self阶段:在潜在空间内进行自注意力
  3. Write阶段:将潜在空间的信息写回到输出序列

通俗解释:想象你在开会做笔记:

  • Read:把别人说的话记录到你的笔记本上
  • Latent Self:在脑海中整理和思考这些信息
  • Write:基于思考结果给出你的回应
def forward(self, x, pos):# Read: 从输入读取到潜在空间z1 = self._attn(self.qL(z0), self.kX(x), self.vX(x))# Latent Self: 潜在空间内自注意力z2 = self._attn(self.qS(z1), self.kS(z1), self.vS(z1))# Write: 从潜在空间写回输出y = self._attn(self.qX2(x), self.kL2(z2), self.vL2(z2))

6. MoE (Mixture of Experts)

核心思想:不是所有神经元都参与每次计算

工作原理

  1. Gate网络:决定激活哪个专家
  2. 专家网络:每个专家负责处理特定类型的输入
  3. 路由机制:根据输入特征选择最合适的专家

通俗解释:就像一个医院,不同的病人会被分配给不同专科的医生。Gate网络是分诊台,专家网络是各科医生,每个"病人"(输入数据)会被送到最合适的"医生"(专家)那里处理。

def forward(self, x):route = self.gate(x).softmax(-1)    # 计算路由概率idx = route.argmax(-1)              # 选择最佳专家out = torch.zeros_like(x)for i, exp in enumerate(self.experts):m = idx == i                    # 找到分配给专家i的数据if m.any(): out[m] = exp(x[m])         # 专家处理对应数据return out

7. SwiGLU 激活函数

组合设计:Swish + GLU (Gated Linear Unit)

公式SwiGLU(x) = Swish(W1(x)) ⊗ W2(x)

优势

  • 结合了Swish的平滑特性
  • 加入了门控机制增强表达能力

通俗解释:像一个智能开关,不仅能控制信号的强弱(Swish部分),还能决定哪些信号可以通过(门控部分)。

def forward(self, x): return self.w3(torch.nn.functional.silu(self.w1(x)) * self.w2(x))#              ↑ Swish激活        ↑ 门控机制   ↑ 输出变换

🔧 训练流程详解

8. 分布式训练设置

初始化过程

  1. 获取本地GPU编号 (LOCAL_RANK)
  2. 设置当前设备
  3. 初始化进程组 (nccl后端)
  4. 将模型包装为DDP

通俗解释:就像组织一个团队项目,每个GPU就是一个团队成员,需要先分配任务、建立通信机制,然后协同工作。

9. 学习率调度

Warmup + 线性递减策略

def lr(it):if it < CFG.warmup: return CFG.lr * it / CFG.warmup      # 预热阶段:逐渐增加return CFG.lr * (1 - it / CFG.total_step)  # 训练阶段:线性递减

通俗解释:就像开车一样,刚开始要慢慢加速(warmup),然后在行程接近结束时逐渐减速,这样能让模型训练更稳定。

10. 训练循环

核心步骤

  1. 前向传播:输入数据,计算预测结果
  2. 计算损失:比较预测和真实结果
  3. 反向传播:计算梯度
  4. 梯度裁剪:防止梯度爆炸
  5. 参数更新:优化模型参数
for x, y in dl:logits = model(x)[:, :-1].reshape(-1, CFG.vocab)          # 前向传播loss = nn.functional.cross_entropy(logits, y[:, 1:].reshape(-1))  # 计算损失loss.backward()                                            # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # 梯度裁剪opt.step()                                                 # 参数更新

📊 性能优化要点

内存优化

  • 梯度检查点:牺牲计算时间换取内存空间
  • 混合精度训练:使用FP16减少内存占用
  • 优化器状态管理set_to_none=True释放内存

计算优化

  • Fused AdamW:融合优化器操作减少kernel启动
  • Pin Memory:加速CPU到GPU的数据传输
  • 非阻塞传输non_blocking=True并行化数据传输

🎯 实际应用建议

硬件要求

  • 最低配置:2张RTX 3090 (24GB显存)
  • 推荐配置:2张RTX 4090 (24GB显存)
  • 内存:至少32GB系统内存

调参建议

  1. 学习率:根据batch size调整,遵循线性缩放规则
  2. Warmup步数:通常设为总训练步数的5-10%
  3. 批次大小:根据显存容量调整,保证梯度稳定

常见问题

  1. OOM (Out of Memory):减少batch_size或max_seq_len
  2. 训练不稳定:检查学习率设置和梯度裁剪阈值
  3. 收敛缓慢:调整warmup策略和学习率调度

🔍 代码扩展建议

功能增强

  • 添加验证集评估
  • 实现模型检查点保存/加载
  • 集成Wandb等实验跟踪工具
  • 支持更多数据格式

性能提升

  • 实现动态batch size
  • 添加梯度累积功能
  • 支持更多优化器选择
  • 集成Flash Attention

这个脚本展示了现代大语言模型训练中的多项前沿技术,是学习和研究Transformer架构的优秀参考实现。

mini_ds.py

# train_mini_deepseek.py ②GPU ⾃动挑卡版
"""
Mini DeepSeek-v3, auto-pick 2 GPUs with max free memory.Run simply:python train_mini_deepseek.py         # auto pick & spawn
or:CUDA_VISIBLE_DEVICES=1,4 torchrun --standalone --nproc_per_node=2 train_mini_deepseek.py
"""import os
import random
import subprocess
import sysimport math
from torch.distributed.elastic.multiprocessing.errors import record# ---------- 0. GPU AUTO-PICK & SELF-SPAWN ----------
def pick_top_gpus(num=2):"""return gpu indices with largest free memory"""try:import pynvml, torchpynvml.nvmlInit()infos = []for i in range(torch.cuda.device_count()):h = pynvml.nvmlDeviceGetHandleByIndex(i)free_mem = pynvml.nvmlDeviceGetMemoryInfo(h).freeinfos.append((free_mem, i))pynvml.nvmlShutdown()infos.sort(reverse=True)  # by free memoryreturn [idx for _, idx in infos[:num]]except Exception:# NVML 失效或 CUDA 不可用 → 默认选前 num 个return list(range(num))if "LOCAL_RANK" not in os.environ:gpu_ids = pick_top_gpus(2)print(f'gpu_ids: {gpu_ids}')os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))print(f"[AutoPick] Use GPUs: {os.environ['CUDA_VISIBLE_DEVICES']}")# 重新启动自身为 2 个分布式进程cmd = ["torchrun", "--standalone", f"--nproc_per_node={len(gpu_ids)}", sys.argv[0], *sys.argv[1:]]subprocess.check_call(cmd)sys.exit(0)# ---------- 1. 之后才 import torch 及重型包 ----------
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DistributedDataParallel as DDPif not hasattr(nn, "RMSNorm"):class RMSNorm(nn.Module):def __init__(self, d, eps=1e-5):super().__init__()self.weight = nn.Parameter(torch.ones(d))self.eps = epsdef forward(self, x):# x: (B, N, d)var = x.pow(2).mean(-1, keepdim=True)x = x * torch.rsqrt(var + self.eps)return self.weight * xnn.RMSNorm = RMSNorm  # 注册到 torch.nn 里# ---------- 2. 配置 ----------
class CFG:vocab = 32_000;max_seq = 1_024d_model = 1_024;n_layer = 6;n_head = 16latent_k = 64mlp_mult = 4;moe_expert = 2lr = 3e-4;warmup = 100;total_step = 1_000batch = 4;seed = 42torch.manual_seed(CFG.seed);
random.seed(CFG.seed)# ---------- 3. 数据 ----------
class RandomDataset(Dataset):def __len__(self): return 10_000_000def __getitem__(self, idx):x = torch.randint(0, CFG.vocab, (CFG.max_seq,))y = torch.roll(x, -1)return x, y# ---------- 4. RoPE ----------
def rope(x, pos):d = x.size(-1);half = d // 2freq = torch.arange(half, device=x.device) / halftheta = pos[:, None] / (10000 ** freq)cos, sin = theta.cos(), theta.sin()xe, xo = x[..., 0::2], x[..., 1::2]return torch.cat([xe * cos - xo * sin, xe * sin + xo * cos], -1)# ---------- 5. MHLA ----------
class MHLA(nn.Module):def __init__(self):super().__init__()d, h, k = CFG.d_model, CFG.n_head, CFG.latent_kself.h, self.k, self.d = h, k, ddef lin(): return nn.Linear(d, d, bias=False)self.qL, self.kX, self.vX = lin(), lin(), lin()self.qS, self.kS, self.vS = lin(), lin(), lin()self.qX2, self.kL2, self.vL2 = lin(), lin(), lin()self.out = lin()self.latent = nn.Parameter(torch.randn(1, k, d) / math.sqrt(d))self.n1 = nn.RMSNorm(d);self.n2 = nn.RMSNorm(d);self.n3 = nn.RMSNorm(d)def _split(self, x):B, N, _ = x.shapereturn x.view(B, N, self.h, -1).permute(0, 2, 1, 3).reshape(B * self.h, N, -1)def _merge(self, x, B, N):  # inversereturn x.view(B, self.h, N, -1).permute(0, 2, 1, 3).reshape(B, N, self.d)def _attn(self, q, k, v):s = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))return (s.softmax(-1) @ v)def forward(self, x, pos):B, N, _ = x.shapez0 = self.latent.expand(B, -1, -1)# readz1 = self._attn(self._split(self.qL(z0)),self._split(self.kX(x)), self._split(self.vX(x)))z1 = self._merge(z1, B, self.k);z1 = self.n1(z0 + self.out(z1))# latent selfz2 = self._attn(self._split(self.qS(z1)),self._split(self.kS(z1)), self._split(self.vS(z1)))z2 = self._merge(z2, B, self.k);z2 = self.n2(z1 + self.out(z2))# writey = self._attn(self._split(self.qX2(x)),self._split(self.kL2(z2)), self._split(self.vL2(z2)))y = self._merge(y, B, N);y = self.n3(x + self.out(y))return y# ---------- 6. MoE FeedForward ----------
class SwiGLU(nn.Module):def __init__(self, d_in, d_hidden):super().__init__()self.w1 = nn.Linear(d_in, d_hidden, False)self.w2 = nn.Linear(d_in, d_hidden, False)self.w3 = nn.Linear(d_hidden, d_in, False)def forward(self, x): return self.w3(torch.nn.functional.silu(self.w1(x)) * self.w2(x))class MoE(nn.Module):def __init__(self):super().__init__()d, h = CFG.d_model, CFG.mlp_mult * CFG.d_modelself.experts = nn.ModuleList([SwiGLU(d, h) for _ in range(CFG.moe_expert)])self.gate = nn.Linear(d, CFG.moe_expert, False)def forward(self, x):route = self.gate(x).softmax(-1)idx = route.argmax(-1)out = torch.zeros_like(x)for i, exp in enumerate(self.experts):m = idx == iif m.any(): out[m] = exp(x[m])return out# ---------- 7. Transformer Block ----------
class Block(nn.Module):def __init__(self, i):super().__init__()self.attn = MHLA() if i % 2 else nn.MultiheadAttention(CFG.d_model, CFG.n_head, batch_first=True)self.norm = nn.RMSNorm(CFG.d_model)self.ffn = MoE()# Pre-compute causal mask for MultiheadAttentionif not isinstance(self.attn, MHLA):self.register_buffer('causal_mask',torch.triu(torch.ones(CFG.max_seq, CFG.max_seq) * float('-inf'), diagonal=1))def forward(self, x, pos):if isinstance(self.attn, nn.MultiheadAttention):q = k = v = rope(x, pos)seq_len = x.size(1)# Use the pre-computed causal mask, truncated to current sequence lengthmask = self.causal_mask[:seq_len, :seq_len]a, _ = self.attn(q, k, v, need_weights=False, attn_mask=mask)else:a = self.attn(rope(x, pos), pos)x = x + ax = x + self.ffn(self.norm(x))return x# ---------- 8. Model ----------
class MiniDeepSeek(nn.Module):def __init__(self):super().__init__()self.embed = nn.Embedding(CFG.vocab, CFG.d_model)self.blocks = nn.ModuleList([Block(i) for i in range(CFG.n_layer)])self.ln_f = nn.RMSNorm(CFG.d_model)self.head = nn.Linear(CFG.d_model, CFG.vocab, False)def forward(self, idx):pos = torch.arange(idx.size(1), device=idx.device)h = self.embed(idx)for blk in self.blocks: h = blk(h, pos)return self.head(self.ln_f(h))# ---------- 9. Train ----------
@record
def main():local_rank = int(os.environ["LOCAL_RANK"])torch.cuda.set_device(local_rank)torch.distributed.init_process_group("nccl")model = MiniDeepSeek().cuda()model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)opt = torch.optim.AdamW(model.parameters(), lr=CFG.lr, fused=True)ds = RandomDataset()sampler = torch.utils.data.DistributedSampler(ds)dl = DataLoader(ds, batch_size=CFG.batch, sampler=sampler,pin_memory=True, num_workers=2)def lr(it):if it < CFG.warmup: return CFG.lr * it / CFG.warmupreturn CFG.lr * (1 - it / CFG.total_step)model.train()step = 0for x, y in dl:step += 1sampler.set_epoch(step)x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)logits = model(x)[:, :-1].reshape(-1, CFG.vocab)loss = nn.functional.cross_entropy(logits, y[:, 1:].reshape(-1))loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)for g in opt.param_groups: g["lr"] = lr(step)opt.step()opt.zero_grad(set_to_none=True)if step % 50 == 0 and torch.distributed.get_rank() == 0:print(f"step {step}/{CFG.total_step} loss {loss.item():.4f}")if step >= CFG.total_step: breaktorch.distributed.destroy_process_group()if __name__ == "__main__":main()
# CUDA_VISIBLE_DEVICES=5,7 python -m torch.distributed.run --nproc_per_node=2  mini_ds.py

运行结果

执行脚本CUDA_VISIBLE_DEVICES=5,7 python -m torch.distributed.run --nproc_per_node=2 mini_ds.py

在这里插入图片描述

在这里插入图片描述

相关文章:

  • 【k8s】阿里云ACK服务中GPU实例部署问题
  • AutoGLM沉思版:智能体推理的Deep Research探索
  • python从环境变量和配置文件中获取配置参数
  • 【面板数据】A股上市公司注册地所在地数据集(1991-2023年)
  • 【免费分享】GWO-BP-AdaBoost预测!灰狼优化、人工神经网络与AdaBoost集成学习算法预测研究
  • 梨泛转录组-文献精读145
  • 基于MATLAB的车牌检测系统:传统图像处理与深度学习的创新融合
  • 使用GpuGeek训练图像分类器:从入门到精通
  • Python实现下载监控工具:自动检测并移动下载文件
  • 计算机视觉与深度学习 | 低照度图像增强算法综述(开源链接,原理,公式,代码)
  • Day53 Python打卡训练营
  • Python Day50
  • 04 - CoordAttention模块
  • Python图片格式转换工具深度解析[附源码】
  • 完整强化学习教程:基于4x4网格世界的智能体探索之旅(一)
  • 2025-06-13【视频处理】基于视频内容转场进行分割
  • 动态规划算法的欢乐密码(二):路径问题
  • Spring Cloud Gateway + JWT 单点登录实现方案(无独立的认证服务器)
  • 最新 Python-PLAXIS 自动化建模技术与典型岩土工程案例实践应用
  • 搭建网站应该怎样选择服务器?
  • 珠海市企业网站制作平台/域名注册万网
  • 注册域名后网站建设/北京seo推广系统
  • 网站导航设计模板源码/市场营销计划方案
  • 做外贸怎么打开国外网站/武汉网站seo公司
  • 西山区城市建设局网站/seo运营招聘
  • 做英文网站常用的字体/外贸推广方式都有哪些