DeepSeek算法学习笔记
学习和总结deepseek优雅的算法设计思想,并提供基础知识回顾,帮助大家快速搞懂deepseek算法思想。
文章目录
- 1 模型架构
- MoE: 混合专家
- KV cache
- MLA: multi-head latent attention
- 2 训练方法
- RL基础
- 经典 RL 算法介绍
- 用于 LLM 训练的 RL 算法
- GRPO 算法
- RL 替代方案
1 模型架构
MoE: 混合专家
在 Transformer attention 模块中的前馈网络 FFNN 部分,改为多个 FFNNs,由路由门控网络选择激活哪个专家,从而减少冗余
KV cache
由于语言序列输入预测下一个token时,前面的token对应的 K V 会被重复计算,因此缓存计算过的KV
MLA: multi-head latent attention
参考视频讲解:https://www.youtube.com/watch?v=0VLAoVGf_74
由于KV cache占内存高达400G,提出思想:将K,V投影到低维空间进行计算attn,最后再上采样
推导过程和结果:所有transformer层共享一个将X下采样的 L_kv 矩阵,每个层有自己的低维Wk,Wv,最后用Wo上采样,可以大大降低KV cache所需存储,且表现更好
2 训练方法
RL基础
- R: 当前reward
- return: 当前及未来reward总和
- Q(s,a):状态-动作价值,按照当前策略时,当前状态下采用当前动作可获得的未来return的期望
- V(s):状态价值,按照当前策略时,当前状态下所有动作Q值的期望
- A(s,a):优势,当前状态下采用当前动作的好处,A(s,a)=Q(s,a)-V(s)
- 贝尔曼方程:当前状态价值=【当前状态的奖励+下一个状态的价值】的期望
- 优势估计方法:
- 蒙特卡洛法MC:对完整轨迹采样计算回报,不适合在线学习;偏差小,但方差较大
- 时序差分法TD:用一步预测来估计Q:A=rt+γV(st+1)−V(t)A=r_t+\gamma V(s_{t+1})-V(t)A=rt+γV(st+1)−V(t) ,方差小但偏差大
- GAE广义优势估计:A=∑n=0∞(γλ)nδtA=\sum_{n=0}^{\infty}(\gamma \lambda)^n\delta_tA=n=0∑∞(γλ)nδt 结合了TD和MC,平衡方差和偏差
- λ=0\lambda=0λ=0, 即时序差分法
- λ=1\lambda=1λ=1, 即蒙特卡洛法
经典 RL 算法介绍
-
DQN
学习最优Q值预测,从而选择最优动作:a = argmaxQ*(s,a)
Q*(s, a) = R(s, a)+γ max(Q*(s’, a’))
假设当前网络预测结果是 Q(s,a),学习率为α,则Q网络更新公式为:
Q(s, a) <-- Q(s, a) + α[R(s, a)+γ max(Q*(s’, a’)) - Q(s, a)] -
DDPG
DQN的连续动作版本,学习Q的同时,再学习一个策略网络使Q最大化 -
Policy gradient
策略由网络预测,根据梯度上升法更新参数: θ←θ+α∇J(πθ)\theta \leftarrow \theta + \alpha \nabla J(\pi_\theta)θ←θ+α∇J(πθ)
可推导出,策略梯度可以用下式估计:
J(πθ)=Eπθ[R(πθ)]J(\pi_\theta) = E_{\pi_\theta} [R(\pi_\theta)]J(πθ)=Eπθ[R(πθ)]
∇J(πθ)=E(∑t=0T∇θlogπθ(at∣st)A(st,at))\nabla J(\pi_\theta) = E(\sum_{t=0}^T \nabla_\theta \log \pi_\theta (a_t|s_t) A(s_t, a_t))∇J(πθ)=E(t=0∑T∇θlogπθ(at∣st)A(st,at)) -
TRPO
参考1、参考2
-
重要性采样:基于分布q(x),估计分布p(x)下的函数f(x)期望
Ep(x)[f(x)]≐Eq(x)[p(x)q(x)f(x)]E_{p(x)}[f(x)] \doteq E_{q(x)}[\frac{p(x)}{q(x)} f(x)] Ep(x)[f(x)]≐Eq(x)[q(x)p(x)f(x)] -
advantage的期望可以用旧策略采样估计:
J(πθ)=Eπ[A(st,at)]=Eπold[ππoldA′(st,at)]J(\pi_\theta) = E_\pi [A(s_t, a_t)]= E_{\pi_{old}}[\frac{\pi}{\pi_{old}}A'(s_t, a_t)]J(πθ)=Eπ[A(st,at)]=Eπold[πoldπA′(st,at)]
并约束新旧策略KL散度,使得新旧策略不要差异太大,更新更稳定- KL散度:衡量两个概率分布的差异
- KL散度:衡量两个概率分布的差异
-
缺点:参数更新需要二阶优化方法,计算成本高
- PPO
- 类似TRPO思路, 但把KL散度加到惩罚项里,而不是作为约束
- 对重要性采样系数进行CLIP
- 可以用梯度下降法完成更新
- 利用GAE估计优势,提高样本利用率
- 网络架构:actor-critic,策略+价值网络
用于 LLM 训练的 RL 算法
- 状态:文本序列
- 动作:next token
- 奖励:预训练的RM reward model 对完整响应进行打分
- 训练步骤:
- SFT:使用监督问答数据,微调预训练LLM
- RM:训练奖励模型
- RL:使用PPO算法,结合人类反馈微调策略模型和价值模型
- 额外加入一个惩罚项,约束RL策略与SFT策略的KL散度
- 传统 GAE 回报计算:LLM的输出是完整应答,因此一句话结束时才有奖励,每个中间token的return都等于最终token的奖励(γ=1),PPO采用单条轨迹重要性采样计算A
GRPO 算法
- DeepSeek 的 GRPO(Group Relative Policy Optimization):
- 用多次采样同一输入的回答的平均奖励作为V的近似估计,提高样本利用率,提升微调效率
- 只需要一个actor模型,降低训练资源需求
- 使用KL正则项,防止策略过度偏离参考策略
- 更新时增加梯度方差约束项,提升训练稳定性
RL 替代方案
偏好排序对齐
2024 PRO: Preference ranking optimization for human alignmen
- RL目的是优化 LLM 符合 RM,RM 目的是对齐人类偏好,为什么不直接优化LLM 对齐人类偏好?
直接从人类偏好排序对齐的角度优化LLM