EFlat-LoRA 的严格数学推导
EFlat-LoRA 的严格数学推导
一、Flat-LoRA基础
首先明确 LoRA 的权重表示、SAM 的优化框架,以及二者直接结合的核心矛盾,为 Flat-LoRA 的推导奠定基础。
1. LoRA 的权重更新公式
LoRA 通过低秩分解表示预训练模型的权重变化,对于任意层的预训练权重 $ W_0 \in \mathbb{R}^{n \times m} $,其微调后的权重为:
W=W0+ΔW=W0+sBA(6) W = W_0 + \Delta W = W_0 + sBA \tag{6}W=W0+ΔW=W0+sBA(6)
其中:
- sss 为缩放因子(控制低秩更新的幅度);
- B∈Rn×rB \in \mathbb{R}^{n \times r}B∈Rn×r、A∈Rr×mA \in \mathbb{R}^{r \times m}A∈Rr×m 为 LoRA 待优化的低秩矩阵(r≪min(n,m)r \ll \min(n,m)r≪min(n,m),保证参数效率);
- 训练中 W0W_0W0 冻结,仅更新AAA(Kaiming 初始化)和 BBB(零初始化)。
2. SAM 与 LoRA 直接结合的矛盾(Naive 方案)
SAM 的核心是通过“极小化最大扰动损失”寻找平坦极小值,其优化目标为:
minwmax∥ε∥≤ρL(w+ε) \min_{w} \max_{\|\varepsilon\| \leq \rho} L(w+\varepsilon)wmin∥ε∥≤ρmaxL(w+ε)
若直接将 SAM 与 LoRA 结合,需对 AAA、BBB 分别加扰动EA∈Rr×mE^A \in \mathbb{R}^{r \times m}EA∈Rr×m、EB∈Rn×rE^B \in \mathbb{R}^{n \times r}EB∈Rn×r(均满足 Frobenius 范数约束∥EA∥F≤ρ\|E^A\|_F \leq \rho∥EA∥F≤ρ、∥EB∥F≤ρ\|E^B\|_F \leq \rho∥EB∥F≤ρ),优化目标变为:
minA,Bmax∥EA∥F≤ρ∥EB∥F≤ρL(W0+s(B+EB)(A+EA))(7)\min_{A,B} \max_{\substack{\|E^A\|_F \leq \rho \\ \|E^B\|_F \leq \rho}} L\left(W_0 + s(B+E^B)(A+E^A)\right) \tag{7}A,Bmin∥EA∥F≤ρ∥EB∥F≤ρmaxL(W0+s(B+EB)(A+EA))(7)
核心矛盾:
- 双扰动EAE^AEA、EBE^BEB 互相干扰,导致低秩子空间计算的“最大损失”与全参数空间(W∈Rn×mW \in \mathbb{R}^{n \times m}W∈Rn×m)的“最大损失”不一致,无法精准找到平坦极小值;
- 全参数梯度 ∇LW(W)\nabla L_W(W)∇LW(W)未知(LoRA 仅优化 AAA、BBB),无法直接套用 SAM 的扰动计算逻辑。
二、Flat-LoRA 的核心推导:扰动重参数化
Flat-LoRA 的核心是将“全参数空间的扰动”通过数学变换,转化为“单一低秩矩阵(B)的扰动”,既对齐全参数空间的损失逻辑,又避免双扰动干扰。
1. 步骤1:全参数空间的扰动建模
首先跳出 LoRA 的低秩子空间,直接在全参数空间定义 SAM 优化目标(目标是找到全参数空间的最大扰动损失):
minA,Bmax∥EW∥F≤ρL(W0+sBA+EW)(8) \min_{A,B} \max_{\|E^W\|_F \leq \rho} L\left(W_0 + sBA + E^W\right) \tag{8} A,Bmin∥EW∥F≤ρmaxL(W0+sBA+EW)(8)
其中 EW∈Rn×mE^W \in \mathbb{R}^{n \times m}EW∈Rn×m是全参数空间的扰动(Frobenius 范数约束),W=W0+sBAW = W_0 + sBAW=W0+sBA 是 LoRA 微调后的全权重。
要解决式 (8) 的“min-max”问题,需先求内层的“max”(即找到最优扰动 E^W\hat{E}^WE^W,使L(W+EW)L(W + E^W)L(W+EW) 最大)。类比 SAM 的泰勒展开近似:
对 L(W+EW)L(W + E^W)L(W+EW) 在 WWW 处一阶泰勒展开,忽略高阶项:
L(W+EW)≈L(W)+Vector(EW)⊤⋅Vector(∇LW(W))L(W + E^W) \approx L(W) + \text{Vector}(E^W)^\top \cdot \text{Vector}\left(\nabla L_W(W)\right)L(W+EW)≈L(W)+Vector(EW)⊤⋅Vector(∇LW(W))
其中 ∇LW(W)=∂L∂W\nabla L_W(W) = \frac{\partial L}{\partial W}∇LW(W)=∂W∂L 是损失对全权重WWW 的梯度,Vector(⋅)\text{Vector}(·)Vector(⋅)表示“矩阵向量化”操作(将 n×mn \times mn×m矩阵转为 nm×1nm \times 1nm×1向量)。
根据柯西-施瓦茨不等式,当 Vector(EW)\text{Vector}(E^W)Vector(EW) 与 Vector(∇LW(W))\text{Vector}(\nabla L_W(W))Vector(∇LW(W)) 同方向时,内积最大。因此最优扰动的向量化形式为:
ε^w=ρ⋅sign(gw)⋅gw∥gw∥(9)\hat{\varepsilon}^w = \rho \cdot \text{sign}(g^w) \cdot \frac{g^w}{\|g^w\|} \tag{9}ε^w=ρ⋅sign(gw)⋅∥gw∥gw(9)
其中 gw=Vector(∇LW(W))g^w = \text{Vector}\left(\nabla L_W(W)\right)gw=Vector(∇LW(W)),sign(⋅)\text{sign}(·)sign(⋅)为符号函数,∥⋅∥\|·\|∥⋅∥ 为 L2L_2L2范数。
将向量化结果还原为矩阵,得到全参数空间的最优扰动:
E^W=Matrix(ε^w)\hat{E}^W = \text{Matrix}\left(\hat{\varepsilon}^w\right) E^W=Matrix(ε^w)
其中 Matrix(⋅)\text{Matrix}(·)Matrix(⋅) 是“向量矩阵化”操作(与 Vector(⋅)\text{Vector}(·)Vector(⋅) 互为逆操作)。
2. 步骤2:用 LoRA 梯度近似全参数梯度∇LW(W)\nabla L_W(W)∇LW(W)
式 (9) 中的 ∇LW(W)\nabla L_W(W)∇LW(W)是未知的(LoRA 仅优化 AAA、BBB,不直接计算全权重梯度),因此需要通过 LoRA 可计算的梯度(∇LA\nabla L_A∇LA、∇LB\nabla L_B∇LB)结合伪逆 近似。
根据链式法则,推导 ∇LA\nabla L_A∇LA、∇LB\nabla L_B∇LB 与 ∇LW(W)\nabla L_W(W)∇LW(W)的关系:
- 对 AAA 求梯度:∇LA=∂L∂A=∂L∂W⋅∂W∂A=s⋅B⊤⋅∇LW(W)\nabla L_A = \frac{\partial L}{\partial A} = \frac{\partial L}{\partial W} \cdot \frac{\partial W}{\partial A} = s \cdot B^\top \cdot \nabla L_W(W)∇LA=∂A∂L=∂W∂L⋅∂A∂W=s⋅B⊤⋅∇LW(W);
- 对BBB 求梯度:∇LB=∂L∂B=∂L∂W⋅∂W∂B=s⋅∇LW(W)⋅A⊤\nabla L_B = \frac{\partial L}{\partial B} = \frac{\partial L}{\partial W} \cdot \frac{\partial W}{\partial B} = s \cdot \nabla L_W(W) \cdot A^\top∇LB=∂B∂L=∂W∂L⋅∂B∂W=s⋅∇LW(W)⋅A⊤。
对上述两式两边分别左乘/右乘伪逆(解决矩阵不可逆问题),解出∇LW(W)\nabla L_W(W)∇LW(W) 的两种近似:
- 由∇LB\nabla L_B∇LB推导:两边右乘(A⊤)+(A^\top)^+(A⊤)+(A⊤A^\topA⊤ 的伪逆),得:
nablaLW(W)≈1s⋅∇LB⋅(A⊤)+(10)nabla L_W(W) \approx \frac{1}{s} \cdot \nabla L_B \cdot (A^\top)^+ \tag{10}nablaLW(W)≈s1⋅∇LB⋅(A⊤)+(10) - 由$\nabla L_A $ 推导:两边左乘 (B⊤)+(B^\top)^+(B⊤)+(B⊤B^\topB⊤ 的伪逆),得:
∇LW(W)≈1s⋅(B⊤)+⋅∇LA(11)\nabla L_W(W) \approx \frac{1}{s} \cdot (B^\top)^+ \cdot \nabla L_A \tag{11}∇LW(W)≈s1⋅(B⊤)+⋅∇LA(11)
为提升近似精度,取两式的平均作为最终的全参数梯度估计:
∇LW(W)‾=0.5⋅[1s∇LB(A⊤)++1s(B⊤)+∇LA](12)\overline{\nabla L_W(W)} = 0.5 \cdot \left[ \frac{1}{s} \nabla L_B (A^\top)^+ + \frac{1}{s} (B^\top)^+ \nabla L_A \right] \tag{12}∇LW(W)=0.5⋅[s1∇LB(A⊤)++s1(B⊤)+∇LA](12)
3. 步骤3:将全参数扰动转化为单一低秩矩阵(B)的扰动
目标是让 LoRA 低秩子空间的扰动损失,与全参数空间的最大损失完全对齐(即式 (8) 的内层“max”等于低秩扰动的损失):
L(W0+s(B+EB)A)=max∥EW∥F≤ρL(W0+sBA+EW)(14)L\left(W_0 + s(B+E^B)A\right) = \max_{\|E^W\|_F \leq \rho} L\left(W_0 + sBA + E^W\right) \tag{14}L(W0+s(B+EB)A)=∥EW∥F≤ρmaxL(W0+sBA+EW)(14)
将步骤1得到的 E^W\hat{E}^WE^W 代入右边(全参数最大损失对应 $E^W = \hat{E}^W $),展开左边:
W0+sBA+sEBA=W0+sBA+E^WW_0 + sBA + sE^B A = W_0 + sBA + \hat{E}^WW0+sBA+sEBA=W0+sBA+E^W
两边消去W0+sBAW_0 + sBAW0+sBA,得sEBA≈E^WsE^B A \approx \hat{E}^WsEBA≈E^W。对两边右乘A+A^+A+(AAA 的伪逆),最终解出BBB的扰动:
EB≈1s⋅E^W⋅A+(15)E^B \approx \frac{1}{s} \cdot \hat{E}^W \cdot A^+ \tag{15}EB≈s1⋅E^W⋅A+(15)
选择 B 而非 A 加扰动的原因:AAA因初始化特性捕捉“跨任务通用特征”,BBB捕捉“任务特定特征”;对BBB加扰动可适配不同任务需求,避免通用特征被干扰。
三、Flat-LoRA 的平衡性质推导(定理1)
为验证 Flat-LoRA 的优化稳定性,文档定义“平衡度”(反映AAA、BBB 参数更新的均衡性),并推导其随训练的变化规律。
1. 平衡度定义
设 xt=Vector(Bt)x_t = \text{Vector}(B_t)xt=Vector(Bt)(BtB_tBt 为 t 步的低秩矩阵),yt=Vector(At)y_t = \text{Vector}(A_t)yt=Vector(At),平衡度定义为:
Bt=12(∥xt∥2−∥yt∥2)B_t = \frac{1}{2} \left( \|x_t\|^2 - \|y_t\|^2 \right)Bt=21(∥xt∥2−∥yt∥2)
2. Flat-LoRA 的参数更新过程
Flat-LoRA 的更新分“扰动→梯度计算→参数更新”三步:
- 扰动 xtx_txt:x~t=xt+ρ⋅1s⋅Gt∥Gt∥⋅yt+\tilde{x}_t = x_t + \rho \cdot \frac{1}{s} \cdot \frac{G_t}{\|G_t\|} \cdot y_t^+x~t=xt+ρ⋅s1⋅∥Gt∥Gt⋅yt+(y~t=yt\tilde{y}_t = y_ty~t=yt,Gt=∇L(xtyt⊤)G_t = \nabla L(x_t y_t^\top)Gt=∇L(xtyt⊤) 为原参数点的全梯度,yt+y_t^+yt+ 为yty_tyt 的伪逆);
- 计算扰动点梯度:gx‾t=G~ty‾tg_{\overline{x}_t} = \tilde{G}_t \overline{y}_tgxt=G~tyt,gy‾t=G~t⊤x~tg_{\overline{y}_t} = \tilde{G}_t^\top \tilde{x}_tgyt=G~t⊤x~t(tildeGt=∇L(x~ty~t⊤)tilde{G}_t = \nabla L(\tilde{x}_t \tilde{y}_t^\top)tildeGt=∇L(x~ty~t⊤) 为扰动点的全梯度);
- 更新参数:xt+1=xt−ηgx‾tx_{t+1} = x_t - \eta g_{\overline{x}_t}xt+1=xt−ηgxt,yt+1=yt−ηgy‾ty_{t+1} = y_t - \eta g_{\overline{y}_t}yt+1=yt−ηgyt(η\etaη为学习率)。
3. 平衡度的导数约束(定理1)
对 BtB_tBt 关于训练步长 ttt求导(当 η→0\eta \to 0η→0 时,参数更新趋近连续流):
dBtdt=12(2xt⊤x˙t−2yt⊤y˙t)=xt⊤x˙t−yt⊤y˙t\frac{d B_t}{d t} = \frac{1}{2} \left( 2x_t^\top \dot{x}_t - 2y_t^\top \dot{y}_t \right) = x_t^\top \dot{x}_t - y_t^\top \dot{y}_tdtdBt=21(2xt⊤x˙t−2yt⊤y˙t)=xt⊤x˙t−yt⊤y˙t
代入 x˙t=−limη→0xt−xt+1η=−gx‾t\dot{x}_t = -\lim_{\eta \to 0} \frac{x_t - x_{t+1}}{\eta} = -g_{\overline{x}_t}x˙t=−limη→0ηxt−xt+1=−gxt、y˙t=−gy‾t\dot{y}_t = -g_{\overline{y}_t}y˙t=−gyt,结合 x~t\tilde{x}_tx~t 的扰动形式,最终通过范数约束推导得:
12d(∥xt∥2−∥yt∥2)dt∣≤∣ρ⋅1s⋅1∥yt∥⋅∥gx‾t∥∣(17)\left. \frac{1}{2} \frac{d \left( \|x_t\|^2 - \|y_t\|^2 \right)}{d t} \right| \leq \left| \rho \cdot \frac{1}{s} \cdot \frac{1}{\|y_t\|} \cdot \|g_{\overline{x}_t}\| \right| \tag{17} 21dtd(∥xt∥2−∥yt∥2)≤ρ⋅s1⋅∥yt∥1⋅∥gxt∥(17)
该式表明:Flat-LoRA 的平衡度会随训练逐渐降低(因 ρ\rhoρ 随训练衰减、∥gx‾t∥\|g_{\overline{x}_t}\|∥gxt∥ 因权重衰减减小),保证参数更新的均衡性,避免优化不稳定。
四、EFlat-LoRA推导前提
EFlat-LoRA是Flat-LoRA的效率优化版,需先明确Flat-LoRA的核心输出——B矩阵的实时扰动EtBE_t^BEtB:
由Flat-LoRA推导可知,全参数空间的最优扰动E^tW\hat{E}_t^WE^tW可转化为LoRA中B矩阵的实时扰动(AAA矩阵固定,无扰动EA=0E^A=0EA=0):
EtB=1s⋅E^tW⋅At+E_t^B = \frac{1}{s} \cdot \hat{E}_t^W \cdot A_t^+ EtB=s1⋅E^tW⋅At+
其中:sss为LoRA缩放因子,At+A_t^+At+为ttt步迭代时AAA矩阵的伪逆,E^tW\hat{E}_t^WE^tW为ttt步全参数空间的最优扰动(由Flat-LoRA的梯度近似得到)。
Flat-LoRA的痛点是:计算EtBE_t^BEtB需两次梯度计算(一次原参数点WtW_tWt,一次扰动参数点Wt+E^tWW_t+\hat{E}_t^WWt+E^tW),时间复杂度为O(2T)O(2T)O(2T)(TTT为LoRA的时间复杂度)。EFlat-LoRA通过“EMA估计EtBE_t^BEtB”将梯度计算次数降为1次,核心是用“历史扰动的加权平均”替代“实时扰动的重复计算”。
五、核心推导1:EMA扰动的数学定义与迭代更新
EFlat-LoRA的核心操作是用EMA扰动E^tB\hat{E}_t^BE^tB近似Flat-LoRA的实时扰动EtBE_t^BEtB,避免每次迭代重新计算EtBE_t^BEtB的梯度开销。
1. EMA扰动的定义
对B矩阵的实时扰动EtBE_t^BEtB(ttt步迭代时的真实扰动),用“指数移动平均”计算平滑后的扰动E^tB\hat{E}_t^BE^tB,迭代公式为:
E^tB=(1−β)⋅E^t−1B+β⋅EtB\hat{E}_t^B = (1-\beta) \cdot \hat{E}_{t-1}^B + \beta \cdot E_t^B E^tB=(1−β)⋅E^t−1B+β⋅EtB
其中:
- β∈(0,1)\beta \in (0,1)β∈(0,1)为动量系数(控制历史扰动的权重,通常取0.9~0.99,原文未指定具体值,需通过实验调优);
- E^t−1B\hat{E}_{t-1}^BE^t−1B为t−1t-1t−1步的EMA扰动(初始值E^0B=0\hat{E}_0^B=0E^0B=0,即首次迭代用实时扰动E1BE_1^BE1B);
- EtBE_t^BEtB为ttt步Flat-LoRA的实时扰动(由式(15)计算,仅在“更新EMA”时需计算1次,无需重复用于梯度下降)。
2. EMA扰动的物理意义
EMA的本质是对历史扰动进行“指数加权平滑”:近期扰动(EtBE_t^BEtB)权重为β\betaβ,远期扰动(E^t−1B\hat{E}_{t-1}^BE^t−1B)权重为(1−β)(1-\beta)(1−β),且权重随时间指数衰减(如β=0.9\beta=0.9β=0.9时,t−2t-2t−2步扰动的权重为β(1−β)\beta(1-\beta)β(1−β),t−3t-3t−3步为β(1−β)2\beta(1-\beta)^2β(1−β)2)。这种平滑能避免单次实时扰动的噪声,同时减少“重复计算EtBE_t^BEtB”的梯度开销。
六、核心推导2:EMA扰动的理论误差界(定理2证明)
EFlat-LoRA的关键理论保障是:EMA扰动E^tB\hat{E}_t^BE^tB与Flat-LoRA的实时扰动EtBE_t^BEtB计算的“尖锐度”误差,会随迭代次数ttt增大而减小。需先明确“尖锐度”的定义,再通过4个标准假设推导误差界。
1. 关键定义:尖锐度与误差
- 尖锐度(Sharpness):衡量损失 landscape 的“平坦程度”,原文中用“扰动前后的损失差”定义:
- SAM/Flat-LoRA的尖锐度(真实尖锐度):StSAM=L(wt+ε~t)−L(wt)S_t^{SAM} = L(w_t + \tilde{\varepsilon}_t) - L(w_t)StSAM=L(wt+ε~t)−L(wt),其中ε~t=Vector(EtB⋅At)\tilde{\varepsilon}_t = \text{Vector}(E_t^B \cdot A_t)ε~t=Vector(EtB⋅At)(B矩阵扰动转化为全参数空间的扰动向量);
- EFlat-LoRA的尖锐度(EMA近似尖锐度):StEMA=L(wt+ε^t−1)−L(wt)S_t^{EMA} = L(w_t + \hat{\varepsilon}_{t-1}) - L(w_t)StEMA=L(wt+ε^t−1)−L(wt),其中ε^t−1=Vector(E^t−1B⋅At)\hat{\varepsilon}_{t-1} = \text{Vector}(\hat{E}_{t-1}^B \cdot A_t)ε^t−1=Vector(E^t−1B⋅At)(EMA扰动转化为全参数空间的扰动向量)。
- 误差目标:证明∣StEMA−StSAM∣|S_t^{EMA} - S_t^{SAM}|∣StEMA−StSAM∣随ttt收敛到0。
2. 推导前提:4个标准假设
原文明确,误差分析需基于4个在SAM类方法中通用的假设(确保推导严谨性):
- Lipschitz平滑性(Assumption 1):损失函数L(w)L(w)L(w)关于参数www满足TTT-Lipschitz平滑,即对任意参数w,vw, vw,v:
∥∇L(w)−∇L(v)∥≤T⋅∥w−v∥ \|\nabla L(w) - \nabla L(v)\| \leq T \cdot \|w - v\|∥∇L(w)−∇L(v)∥≤T⋅∥w−v∥
(TTT为Lipschitz常数,衡量梯度变化的平缓程度); - 梯度有界(Assumption 2):每个mini-batch的梯度范数存在上界GGG,即:
E[∥∇L(w)∥]≤G\mathbb{E}\left[\|\nabla L(w)\|\right] \leq GE[∥∇L(w)∥]≤G
(E[⋅]\mathbb{E}[\cdot]E[⋅]为期望,避免单个batch的异常梯度影响); - 随机梯度方差有界(Assumption 3):设训练集为DDD,mini-batch为B⊆DB \subseteq DB⊆D,随机梯度∇LB(w)\nabla L_B(w)∇LB(w)与全数据集梯度∇LD(w)\nabla L_D(w)∇LD(w)的方差存在上界σ2\sigma^2σ2:
E[∥∇LB(w)−∇LD(w)∥2]≤σ2 \mathbb{E}\left[\|\nabla L_B(w) - \nabla L_D(w)\|^2\right] \leq \sigma^2E[∥∇LB(w)−∇LD(w)∥2]≤σ2
(控制mini-batch采样带来的梯度噪声); - 局部凸性(Assumption 4):微调阶段,模型参数接近局部极小值,损失函数在局部邻域内凸且二次可微,即对任意x,yx, yx,y:
L(y)≥L(x)+∇L(x)⊤(y−x)L(y) \geq L(x) + \nabla L(x)^\top (y - x)L(y)≥L(x)+∇L(x)⊤(y−x)
(凸性保证误差分析可通过不等式链推导,无局部震荡)。
3. 定理2:EMA扰动的误差界证明
定理2表述:若满足上述4个假设,且扰动半径随迭代衰减ρt=ρ0t\rho_t = \frac{\rho_0}{\sqrt{t}}ρt=tρ0(ρ0\rho_0ρ0为初始半径),则ttt步时EMA尖锐度与SAM尖锐度的误差满足:
∣StEMA−StSAM∣≤(T⋅ρ0t−1+G+σ2)⋅(ρ0t+ρ0(1−β)t−1+ρ0) \left| S_t^{EMA} - S_t^{SAM} \right| \leq \left( T \cdot \frac{\rho_0}{\sqrt{t-1}} + G + \sigma^2 \right) \cdot \left( \frac{\rho_0}{\sqrt{t}} + \rho_0 (1-\beta)^{t-1} + \rho_0 \right) StEMA−StSAM≤(T⋅t−1ρ0+G+σ2)⋅(tρ0+ρ0(1−β)t−1+ρ0)
推导步骤:
-
用Lipschitz平滑性展开损失差:
对StEMA=L(wt+ε^t−1)−L(wt)S_t^{EMA} = L(w_t + \hat{\varepsilon}_{t-1}) - L(w_t)StEMA=L(wt+ε^t−1)−L(wt)和StSAM=L(wt+ε~t)−L(wt)S_t^{SAM} = L(w_t + \tilde{\varepsilon}_t) - L(w_t)StSAM=L(wt+ε~t)−L(wt),由Lipschitz平滑性的推论(损失差≤梯度范数×参数差):
∣L(wt+ε^t−1)−L(wt+ε~t)∣≤supw∈[wt+ε^t−1,wt+ε~t]∥∇L(w)∥⋅∥ε^t−1−ε~t∥ \left| L(w_t + \hat{\varepsilon}_{t-1}) - L(w_t + \tilde{\varepsilon}_t) \right| \leq \sup_{w \in [w_t+\hat{\varepsilon}_{t-1}, w_t+\tilde{\varepsilon}_t]} \|\nabla L(w)\| \cdot \|\hat{\varepsilon}_{t-1} - \tilde{\varepsilon}_t\|∣L(wt+ε^t−1)−L(wt+ε~t)∣≤w∈[wt+ε^t−1,wt+ε~t]sup∥∇L(w)∥⋅∥ε^t−1−ε~t∥
其中[a,b][a,b][a,b]表示aaa与bbb之间的参数区间,sup\supsup为上确界。 -
控制梯度范数的上界:
由Assumption 2(梯度有界E[∥∇L(w)∥]≤G\mathbb{E}[\|\nabla L(w)\|] \leq GE[∥∇L(w)∥]≤G)和Assumption 3(方差有界σ2\sigma^2σ2),结合Cauchy-Schwarz不等式,可推出:
supw∈[wt+ε^t−1,wt+ε~t]∥∇L(w)∥≤∥∇L(wt)∥+T⋅max(∥ε^t−1∥,∥ε~t∥)+σ2 \sup_{w \in [w_t+\hat{\varepsilon}_{t-1}, w_t+\tilde{\varepsilon}_t]} \|\nabla L(w)\| \leq \|\nabla L(w_t)\| + T \cdot \max\left(\|\hat{\varepsilon}_{t-1}\|, \|\tilde{\varepsilon}_t\|\right) + \sigma^2 w∈[wt+ε^t−1,wt+ε~t]sup∥∇L(w)∥≤∥∇L(wt)∥+T⋅max(∥ε^t−1∥,∥ε~t∥)+σ2
再代入∥∇L(wt)∥≤G\|\nabla L(w_t)\| \leq G∥∇L(wt)∥≤G(Assumption 2)和max(∥ε^t−1∥,∥ε~t∥)≤ρt−1=ρ0t−1\max\left(\|\hat{\varepsilon}_{t-1}\|, \|\tilde{\varepsilon}_t\|\right) \leq \rho_{t-1} = \frac{\rho_0}{\sqrt{t-1}}max(∥ε^t−1∥,∥ε~t∥)≤ρt−1=t−1ρ0(扰动半径衰减),得:
sup∥∇L(w)∥≤T⋅ρ0t−1+G+σ2 \sup \|\nabla L(w)\| \leq T \cdot \frac{\rho_0}{\sqrt{t-1}} + G + \sigma^2sup∥∇L(w)∥≤T⋅t−1ρ0+G+σ2 -
控制参数差∥ε^t−1−ε~t∥\|\hat{\varepsilon}_{t-1} - \tilde{\varepsilon}_t\|∥ε^t−1−ε~t∥的上界:
由EMA定义E^t−1B=(1−β)E^t−2B+βEt−1B\hat{E}_{t-1}^B = (1-\beta)\hat{E}_{t-2}^B + \beta E_{t-1}^BE^t−1B=(1−β)E^t−2B+βEt−1B,递推可得E^t−1B\hat{E}_{t-1}^BE^t−1B是历史E1B,...,Et−1BE_1^B, ..., E_{t-1}^BE1B,...,Et−1B的加权和:
E^t−1B=β∑k=1t−1(1−β)t−1−kEkB\hat{E}_{t-1}^B = \beta \sum_{k=1}^{t-1} (1-\beta)^{t-1-k} E_k^BE^t−1B=βk=1∑t−1(1−β)t−1−kEkB
因此,ε^t−1=Vector(E^t−1BAt)\hat{\varepsilon}_{t-1} = \text{Vector}(\hat{E}_{t-1}^B A_t)ε^t−1=Vector(E^t−1BAt)与ε~t=Vector(EtBAt)\tilde{\varepsilon}_t = \text{Vector}(E_t^B A_t)ε~t=Vector(EtBAt)的差的范数满足:
∥ε^t−1−ε~t∥≤∥ε^t−1∥+∥ε~t∥ \|\hat{\varepsilon}_{t-1} - \tilde{\varepsilon}_t\| \leq \|\hat{\varepsilon}_{t-1}\| + \|\tilde{\varepsilon}_t\|∥ε^t−1−ε~t∥≤∥ε^t−1∥+∥ε~t∥
(三角不等式)。
其中:- ∥ε~t∥≤ρt=ρ0t\|\tilde{\varepsilon}_t\| \leq \rho_t = \frac{\rho_0}{\sqrt{t}}∥ε~t∥≤ρt=tρ0(当前扰动半径);
- ∥ε^t−1∥≤β∑k=1t−1(1−β)t−1−k⋅∥εk∥≤βρ0∑k=1t−1(1−β)t−1−k=ρ0(1−(1−β)t−1)\|\hat{\varepsilon}_{t-1}\| \leq \beta \sum_{k=1}^{t-1} (1-\beta)^{t-1-k} \cdot \|\varepsilon_k\| \leq \beta \rho_0 \sum_{k=1}^{t-1} (1-\beta)^{t-1-k} = \rho_0 (1 - (1-\beta)^{t-1})∥ε^t−1∥≤β∑k=1t−1(1−β)t−1−k⋅∥εk∥≤βρ0∑k=1t−1(1−β)t−1−k=ρ0(1−(1−β)t−1)(等比数列求和:∑k=0n(1−β)k=1−(1−β)n+1β\sum_{k=0}^{n} (1-\beta)^k = \frac{1 - (1-\beta)^{n+1}}{\beta}∑k=0n(1−β)k=β1−(1−β)n+1);
代入得:
∥ε^t−1−ε~t∥≤ρ0(1−(1−β)t−1)+ρ0t≤ρ0t+ρ0(1−β)t−1+ρ0\|\hat{\varepsilon}_{t-1} - \tilde{\varepsilon}_t\| \leq \rho_0 (1 - (1-\beta)^{t-1}) + \frac{\rho_0}{\sqrt{t}} \leq \frac{\rho_0}{\sqrt{t}} + \rho_0 (1-\beta)^{t-1} + \rho_0∥ε^t−1−ε~t∥≤ρ0(1−(1−β)t−1)+tρ0≤tρ0+ρ0(1−β)t−1+ρ0
(因1−(1−β)t−1≤1+ρ0(1−β)t−11 - (1-\beta)^{t-1} \leq 1 + \rho_0 (1-\beta)^{t-1}1−(1−β)t−1≤1+ρ0(1−β)t−1,简化后不影响收敛性)。
-
合并误差界:
将步骤2和步骤3的结果代入步骤1的不等式,最终得到:
∣StEMA−StSAM∣≤(T⋅ρ0t−1+G+σ2)⋅(ρ0t+ρ0(1−β)t−1+ρ0) \left| S_t^{EMA} - S_t^{SAM} \right| \leq \left( T \cdot \frac{\rho_0}{\sqrt{t-1}} + G + \sigma^2 \right) \cdot \left( \frac{\rho_0}{\sqrt{t}} + \rho_0 (1-\beta)^{t-1} + \rho_0 \right) StEMA−StSAM≤(T⋅t−1ρ0+G+σ2)⋅(tρ0+ρ0(1−β)t−1+ρ0)
关键结论:当迭代次数t→+∞t \to +\inftyt→+∞时,ρ0t→0\frac{\rho_0}{\sqrt{t}} \to 0tρ0→0且(1−β)t−1→0(1-\beta)^{t-1} \to 0(1−β)t−1→0(β∈(0,1)\beta \in (0,1)β∈(0,1)),因此误差∣StEMA−StSAM∣→0\left| S_t^{EMA} - S_t^{SAM} \right| \to 0StEMA−StSAM→0,即EMA扰动能无限逼近Flat-LoRA的实时扰动。
七、核心推导3:EFlat-LoRA的计算复杂度验证
EFlat-LoRA的“效率”需通过参数复杂度、内存复杂度、时间复杂度的数学推导验证,证明其与LoRA效率相当。
1. 参数复杂度(训练参数数量)
EFlat-LoRA未新增可训练参数,仅复用LoRA的AAA(r×mr \times mr×m)和BBB(n×rn \times rn×r)矩阵,因此参数数量与LoRA、Flat-LoRA完全一致:
PEFlat-LoRA=PLoRA=PFlat-LoRA=O(nr+rm)≪O(nm)P_{\text{EFlat-LoRA}} = P_{\text{LoRA}} = P_{\text{Flat-LoRA}} = O(nr + rm) \ll O(nm) PEFlat-LoRA=PLoRA=PFlat-LoRA=O(nr+rm)≪O(nm)
其中n×mn \times mn×m为原模型权重矩阵维度,r≪min(n,m)r \ll \min(n,m)r≪min(n,m)(低秩约束),保证参数效率。
2. 内存复杂度(存储开销)
EFlat-LoRA的额外内存开销仅来自“EMA扰动E^tB\hat{E}_t^BE^tB的存储”(n×rn \times rn×r矩阵,因仅对B加扰动),需对比Flat-LoRA:
- Flat-LoRA需存储“AAA的梯度+原A/BA/BA/B备份”,内存为MLoRA+O(1.5(nr+rm))M_{\text{LoRA}} + O(1.5(nr + rm))MLoRA+O(1.5(nr+rm));
- EFlat-LoRA需存储“EMA扰动E^tB\hat{E}_t^BE^tB”(O(nr)O(nr)O(nr)),且现代优化器(如AdamW)已需存储“动量+二阶动量”(O(2(nr+rm))O(2(nr + rm))O(2(nr+rm))),因此EFlat-LoRA的内存为:
MEFlat-LoRA=MLoRA+O(2(nr+rm)) M_{\text{EFlat-LoRA}} = M_{\text{LoRA}} + O(2(nr + rm)) MEFlat-LoRA=MLoRA+O(2(nr+rm)) 。
3. 时间复杂度(训练耗时)
时间复杂度的核心是“前向-反向传播次数”:
- LoRA:每次迭代1次前向+1次反向,时间复杂度O(T)O(T)O(T)(TTT为单轮前向-反向时间);
- Flat-LoRA:每次迭代需2次前向+2次反向(原参数点+扰动参数点),时间复杂度O(2T)O(2T)O(2T);
- EFlat-LoRA:因用EMA扰动E^t−1B\hat{E}_{t-1}^BE^t−1B近似实时扰动EtBE_t^BEtB,仅需在“更新EMA时计算1次EtBE_t^BEtB”,每次迭代仅1次前向+1次反向,时间复杂度:
TEFlat-LoRA=O(T)=TLoRA T_{\text{EFlat-LoRA}} = O(T) = T_{\text{LoRA}} TEFlat-LoRA=O(T)=TLoRA
(与LoRA效率完全一致,解决Flat-LoRA的耗时问题)。