当前位置: 首页 > news >正文

EFlat-LoRA 的严格数学推导

EFlat-LoRA 的严格数学推导

一、Flat-LoRA基础

首先明确 LoRA 的权重表示、SAM 的优化框架,以及二者直接结合的核心矛盾,为 Flat-LoRA 的推导奠定基础。

1. LoRA 的权重更新公式

LoRA 通过低秩分解表示预训练模型的权重变化,对于任意层的预训练权重 $ W_0 \in \mathbb{R}^{n \times m} $,其微调后的权重为:
W=W0+ΔW=W0+sBA(6) W = W_0 + \Delta W = W_0 + sBA \tag{6}W=W0+ΔW=W0+sBA(6)
其中:

  • sss 为缩放因子(控制低秩更新的幅度);
  • B∈Rn×rB \in \mathbb{R}^{n \times r}BRn×rA∈Rr×mA \in \mathbb{R}^{r \times m}ARr×m 为 LoRA 待优化的低秩矩阵(r≪min⁡(n,m)r \ll \min(n,m)rmin(n,m),保证参数效率);
  • 训练中 W0W_0W0 冻结,仅更新AAA(Kaiming 初始化)和 BBB(零初始化)。
2. SAM 与 LoRA 直接结合的矛盾(Naive 方案)

SAM 的核心是通过“极小化最大扰动损失”寻找平坦极小值,其优化目标为:
min⁡wmax⁡∥ε∥≤ρL(w+ε) \min_{w} \max_{\|\varepsilon\| \leq \rho} L(w+\varepsilon)wminερmaxL(w+ε)
若直接将 SAM 与 LoRA 结合,需对 AAABBB 分别加扰动EA∈Rr×mE^A \in \mathbb{R}^{r \times m}EARr×mEB∈Rn×rE^B \in \mathbb{R}^{n \times r}EBRn×r(均满足 Frobenius 范数约束∥EA∥F≤ρ\|E^A\|_F \leq \rhoEAFρ∥EB∥F≤ρ\|E^B\|_F \leq \rhoEBFρ),优化目标变为:
min⁡A,Bmax⁡∥EA∥F≤ρ∥EB∥F≤ρL(W0+s(B+EB)(A+EA))(7)\min_{A,B} \max_{\substack{\|E^A\|_F \leq \rho \\ \|E^B\|_F \leq \rho}} L\left(W_0 + s(B+E^B)(A+E^A)\right) \tag{7}A,BminEAFρEBFρmaxL(W0+s(B+EB)(A+EA))(7)
核心矛盾

  • 双扰动EAE^AEAEBE^BEB 互相干扰,导致低秩子空间计算的“最大损失”与全参数空间(W∈Rn×mW \in \mathbb{R}^{n \times m}WRn×m)的“最大损失”不一致,无法精准找到平坦极小值;
  • 全参数梯度 ∇LW(W)\nabla L_W(W)LW(W)未知(LoRA 仅优化 AAABBB),无法直接套用 SAM 的扰动计算逻辑。

二、Flat-LoRA 的核心推导:扰动重参数化

Flat-LoRA 的核心是将“全参数空间的扰动”通过数学变换,转化为“单一低秩矩阵(B)的扰动”,既对齐全参数空间的损失逻辑,又避免双扰动干扰。

1. 步骤1:全参数空间的扰动建模

首先跳出 LoRA 的低秩子空间,直接在全参数空间定义 SAM 优化目标(目标是找到全参数空间的最大扰动损失):
min⁡A,Bmax⁡∥EW∥F≤ρL(W0+sBA+EW)(8) \min_{A,B} \max_{\|E^W\|_F \leq \rho} L\left(W_0 + sBA + E^W\right) \tag{8} A,BminEWFρmaxL(W0+sBA+EW)(8)
其中 EW∈Rn×mE^W \in \mathbb{R}^{n \times m}EWRn×m是全参数空间的扰动(Frobenius 范数约束),W=W0+sBAW = W_0 + sBAW=W0+sBA 是 LoRA 微调后的全权重。

要解决式 (8) 的“min-max”问题,需先求内层的“max”(即找到最优扰动 E^W\hat{E}^WE^W,使L(W+EW)L(W + E^W)L(W+EW) 最大)。类比 SAM 的泰勒展开近似:
L(W+EW)L(W + E^W)L(W+EW)WWW 处一阶泰勒展开,忽略高阶项:
L(W+EW)≈L(W)+Vector(EW)⊤⋅Vector(∇LW(W))L(W + E^W) \approx L(W) + \text{Vector}(E^W)^\top \cdot \text{Vector}\left(\nabla L_W(W)\right)L(W+EW)L(W)+Vector(EW)Vector(LW(W))
其中 ∇LW(W)=∂L∂W\nabla L_W(W) = \frac{\partial L}{\partial W}LW(W)=WL 是损失对全权重WWW 的梯度,Vector(⋅)\text{Vector}(·)Vector()表示“矩阵向量化”操作(将 n×mn \times mn×m矩阵转为 nm×1nm \times 1nm×1向量)。

根据柯西-施瓦茨不等式,当 Vector(EW)\text{Vector}(E^W)Vector(EW)Vector(∇LW(W))\text{Vector}(\nabla L_W(W))Vector(LW(W)) 同方向时,内积最大。因此最优扰动的向量化形式为:
ε^w=ρ⋅sign(gw)⋅gw∥gw∥(9)\hat{\varepsilon}^w = \rho \cdot \text{sign}(g^w) \cdot \frac{g^w}{\|g^w\|} \tag{9}ε^w=ρsign(gw)gwgw(9)
其中 gw=Vector(∇LW(W))g^w = \text{Vector}\left(\nabla L_W(W)\right)gw=Vector(LW(W))sign(⋅)\text{sign}(·)sign()为符号函数,∥⋅∥\|·\|L2L_2L2范数。

将向量化结果还原为矩阵,得到全参数空间的最优扰动:
E^W=Matrix(ε^w)\hat{E}^W = \text{Matrix}\left(\hat{\varepsilon}^w\right) E^W=Matrix(ε^w)
其中 Matrix(⋅)\text{Matrix}(·)Matrix() 是“向量矩阵化”操作(与 Vector(⋅)\text{Vector}(·)Vector() 互为逆操作)。

2. 步骤2:用 LoRA 梯度近似全参数梯度∇LW(W)\nabla L_W(W)LW(W)

式 (9) 中的 ∇LW(W)\nabla L_W(W)LW(W)是未知的(LoRA 仅优化 AAABBB,不直接计算全权重梯度),因此需要通过 LoRA 可计算的梯度(∇LA\nabla L_ALA∇LB\nabla L_BLB)结合伪逆 近似。

根据链式法则,推导 ∇LA\nabla L_ALA∇LB\nabla L_BLB∇LW(W)\nabla L_W(W)LW(W)的关系:

  • AAA 求梯度:∇LA=∂L∂A=∂L∂W⋅∂W∂A=s⋅B⊤⋅∇LW(W)\nabla L_A = \frac{\partial L}{\partial A} = \frac{\partial L}{\partial W} \cdot \frac{\partial W}{\partial A} = s \cdot B^\top \cdot \nabla L_W(W)LA=AL=WLAW=sBLW(W)
  • BBB 求梯度:∇LB=∂L∂B=∂L∂W⋅∂W∂B=s⋅∇LW(W)⋅A⊤\nabla L_B = \frac{\partial L}{\partial B} = \frac{\partial L}{\partial W} \cdot \frac{\partial W}{\partial B} = s \cdot \nabla L_W(W) \cdot A^\topLB=BL=WLBW=sLW(W)A

对上述两式两边分别左乘/右乘伪逆(解决矩阵不可逆问题),解出∇LW(W)\nabla L_W(W)LW(W) 的两种近似:

  1. ∇LB\nabla L_BLB推导:两边右乘(A⊤)+(A^\top)^+(A)+A⊤A^\topA 的伪逆),得:
    nablaLW(W)≈1s⋅∇LB⋅(A⊤)+(10)nabla L_W(W) \approx \frac{1}{s} \cdot \nabla L_B \cdot (A^\top)^+ \tag{10}nablaLW(W)s1LB(A)+(10)
  2. 由$\nabla L_A $ 推导:两边左乘 (B⊤)+(B^\top)^+(B)+B⊤B^\topB 的伪逆),得:
    ∇LW(W)≈1s⋅(B⊤)+⋅∇LA(11)\nabla L_W(W) \approx \frac{1}{s} \cdot (B^\top)^+ \cdot \nabla L_A \tag{11}LW(W)s1(B)+LA(11)

为提升近似精度,取两式的平均作为最终的全参数梯度估计:
∇LW(W)‾=0.5⋅[1s∇LB(A⊤)++1s(B⊤)+∇LA](12)\overline{\nabla L_W(W)} = 0.5 \cdot \left[ \frac{1}{s} \nabla L_B (A^\top)^+ + \frac{1}{s} (B^\top)^+ \nabla L_A \right] \tag{12}LW(W)=0.5[s1LB(A)++s1(B)+LA](12)

3. 步骤3:将全参数扰动转化为单一低秩矩阵(B)的扰动

目标是让 LoRA 低秩子空间的扰动损失,与全参数空间的最大损失完全对齐(即式 (8) 的内层“max”等于低秩扰动的损失):
L(W0+s(B+EB)A)=max⁡∥EW∥F≤ρL(W0+sBA+EW)(14)L\left(W_0 + s(B+E^B)A\right) = \max_{\|E^W\|_F \leq \rho} L\left(W_0 + sBA + E^W\right) \tag{14}L(W0+s(B+EB)A)=EWFρmaxL(W0+sBA+EW)(14)

将步骤1得到的 E^W\hat{E}^WE^W 代入右边(全参数最大损失对应 $E^W = \hat{E}^W $),展开左边:
W0+sBA+sEBA=W0+sBA+E^WW_0 + sBA + sE^B A = W_0 + sBA + \hat{E}^WW0+sBA+sEBA=W0+sBA+E^W

两边消去W0+sBAW_0 + sBAW0+sBA,得sEBA≈E^WsE^B A \approx \hat{E}^WsEBAE^W。对两边右乘A+A^+A+AAA 的伪逆),最终解出BBB的扰动:
EB≈1s⋅E^W⋅A+(15)E^B \approx \frac{1}{s} \cdot \hat{E}^W \cdot A^+ \tag{15}EBs1E^WA+(15)

选择 B 而非 A 加扰动的原因AAA因初始化特性捕捉“跨任务通用特征”,BBB捕捉“任务特定特征”;对BBB加扰动可适配不同任务需求,避免通用特征被干扰。

三、Flat-LoRA 的平衡性质推导(定理1)

为验证 Flat-LoRA 的优化稳定性,文档定义“平衡度”(反映AAABBB 参数更新的均衡性),并推导其随训练的变化规律。

1. 平衡度定义

xt=Vector(Bt)x_t = \text{Vector}(B_t)xt=Vector(Bt)BtB_tBt 为 t 步的低秩矩阵),yt=Vector(At)y_t = \text{Vector}(A_t)yt=Vector(At),平衡度定义为:
Bt=12(∥xt∥2−∥yt∥2)B_t = \frac{1}{2} \left( \|x_t\|^2 - \|y_t\|^2 \right)Bt=21(xt2yt2)

2. Flat-LoRA 的参数更新过程

Flat-LoRA 的更新分“扰动→梯度计算→参数更新”三步:

  1. 扰动 xtx_txtx~t=xt+ρ⋅1s⋅Gt∥Gt∥⋅yt+\tilde{x}_t = x_t + \rho \cdot \frac{1}{s} \cdot \frac{G_t}{\|G_t\|} \cdot y_t^+x~t=xt+ρs1GtGtyt+y~t=yt\tilde{y}_t = y_ty~t=ytGt=∇L(xtyt⊤)G_t = \nabla L(x_t y_t^\top)Gt=L(xtyt) 为原参数点的全梯度,yt+y_t^+yt+yty_tyt 的伪逆);
  2. 计算扰动点梯度:gx‾t=G~ty‾tg_{\overline{x}_t} = \tilde{G}_t \overline{y}_tgxt=G~tytgy‾t=G~t⊤x~tg_{\overline{y}_t} = \tilde{G}_t^\top \tilde{x}_tgyt=G~tx~ttildeGt=∇L(x~ty~t⊤)tilde{G}_t = \nabla L(\tilde{x}_t \tilde{y}_t^\top)tildeGt=L(x~ty~t) 为扰动点的全梯度);
  3. 更新参数:xt+1=xt−ηgx‾tx_{t+1} = x_t - \eta g_{\overline{x}_t}xt+1=xtηgxtyt+1=yt−ηgy‾ty_{t+1} = y_t - \eta g_{\overline{y}_t}yt+1=ytηgytη\etaη为学习率)。
3. 平衡度的导数约束(定理1)

BtB_tBt 关于训练步长 ttt求导(当 η→0\eta \to 0η0 时,参数更新趋近连续流):
dBtdt=12(2xt⊤x˙t−2yt⊤y˙t)=xt⊤x˙t−yt⊤y˙t\frac{d B_t}{d t} = \frac{1}{2} \left( 2x_t^\top \dot{x}_t - 2y_t^\top \dot{y}_t \right) = x_t^\top \dot{x}_t - y_t^\top \dot{y}_tdtdBt=21(2xtx˙t2yty˙t)=xtx˙tyty˙t

代入 x˙t=−lim⁡η→0xt−xt+1η=−gx‾t\dot{x}_t = -\lim_{\eta \to 0} \frac{x_t - x_{t+1}}{\eta} = -g_{\overline{x}_t}x˙t=limη0ηxtxt+1=gxty˙t=−gy‾t\dot{y}_t = -g_{\overline{y}_t}y˙t=gyt,结合 x~t\tilde{x}_tx~t 的扰动形式,最终通过范数约束推导得:
12d(∥xt∥2−∥yt∥2)dt∣≤∣ρ⋅1s⋅1∥yt∥⋅∥gx‾t∥∣(17)\left. \frac{1}{2} \frac{d \left( \|x_t\|^2 - \|y_t\|^2 \right)}{d t} \right| \leq \left| \rho \cdot \frac{1}{s} \cdot \frac{1}{\|y_t\|} \cdot \|g_{\overline{x}_t}\| \right| \tag{17} 21dtd(xt2yt2)ρs1yt1gxt(17)

该式表明:Flat-LoRA 的平衡度会随训练逐渐降低(因 ρ\rhoρ 随训练衰减、∥gx‾t∥\|g_{\overline{x}_t}\|gxt 因权重衰减减小),保证参数更新的均衡性,避免优化不稳定。

四、EFlat-LoRA推导前提

EFlat-LoRA是Flat-LoRA的效率优化版,需先明确Flat-LoRA的核心输出——B矩阵的实时扰动EtBE_t^BEtB
由Flat-LoRA推导可知,全参数空间的最优扰动E^tW\hat{E}_t^WE^tW可转化为LoRA中B矩阵的实时扰动(AAA矩阵固定,无扰动EA=0E^A=0EA=0):
EtB=1s⋅E^tW⋅At+E_t^B = \frac{1}{s} \cdot \hat{E}_t^W \cdot A_t^+ EtB=s1E^tWAt+
其中:sss为LoRA缩放因子,At+A_t^+At+ttt步迭代时AAA矩阵的伪逆,E^tW\hat{E}_t^WE^tWttt步全参数空间的最优扰动(由Flat-LoRA的梯度近似得到)。

Flat-LoRA的痛点是:计算EtBE_t^BEtB两次梯度计算(一次原参数点WtW_tWt,一次扰动参数点Wt+E^tWW_t+\hat{E}_t^WWt+E^tW),时间复杂度为O(2T)O(2T)O(2T)TTT为LoRA的时间复杂度)。EFlat-LoRA通过“EMA估计EtBE_t^BEtB”将梯度计算次数降为1次,核心是用“历史扰动的加权平均”替代“实时扰动的重复计算”。

五、核心推导1:EMA扰动的数学定义与迭代更新

EFlat-LoRA的核心操作是用EMA扰动E^tB\hat{E}_t^BE^tB近似Flat-LoRA的实时扰动EtBE_t^BEtB,避免每次迭代重新计算EtBE_t^BEtB的梯度开销。

1. EMA扰动的定义

对B矩阵的实时扰动EtBE_t^BEtBttt步迭代时的真实扰动),用“指数移动平均”计算平滑后的扰动E^tB\hat{E}_t^BE^tB,迭代公式为:
E^tB=(1−β)⋅E^t−1B+β⋅EtB\hat{E}_t^B = (1-\beta) \cdot \hat{E}_{t-1}^B + \beta \cdot E_t^B E^tB=(1β)E^t1B+βEtB
其中:

  • β∈(0,1)\beta \in (0,1)β(0,1)为动量系数(控制历史扰动的权重,通常取0.9~0.99,原文未指定具体值,需通过实验调优);
  • E^t−1B\hat{E}_{t-1}^BE^t1Bt−1t-1t1步的EMA扰动(初始值E^0B=0\hat{E}_0^B=0E^0B=0,即首次迭代用实时扰动E1BE_1^BE1B);
  • EtBE_t^BEtBttt步Flat-LoRA的实时扰动(由式(15)计算,仅在“更新EMA”时需计算1次,无需重复用于梯度下降)。
2. EMA扰动的物理意义

EMA的本质是对历史扰动进行“指数加权平滑”:近期扰动(EtBE_t^BEtB)权重为β\betaβ,远期扰动(E^t−1B\hat{E}_{t-1}^BE^t1B)权重为(1−β)(1-\beta)(1β),且权重随时间指数衰减(如β=0.9\beta=0.9β=0.9时,t−2t-2t2步扰动的权重为β(1−β)\beta(1-\beta)β(1β)t−3t-3t3步为β(1−β)2\beta(1-\beta)^2β(1β)2)。这种平滑能避免单次实时扰动的噪声,同时减少“重复计算EtBE_t^BEtB”的梯度开销。

六、核心推导2:EMA扰动的理论误差界(定理2证明)

EFlat-LoRA的关键理论保障是:EMA扰动E^tB\hat{E}_t^BE^tB与Flat-LoRA的实时扰动EtBE_t^BEtB计算的“尖锐度”误差,会随迭代次数ttt增大而减小。需先明确“尖锐度”的定义,再通过4个标准假设推导误差界。

1. 关键定义:尖锐度与误差
  • 尖锐度(Sharpness):衡量损失 landscape 的“平坦程度”,原文中用“扰动前后的损失差”定义:
    • SAM/Flat-LoRA的尖锐度(真实尖锐度):StSAM=L(wt+ε~t)−L(wt)S_t^{SAM} = L(w_t + \tilde{\varepsilon}_t) - L(w_t)StSAM=L(wt+ε~t)L(wt),其中ε~t=Vector(EtB⋅At)\tilde{\varepsilon}_t = \text{Vector}(E_t^B \cdot A_t)ε~t=Vector(EtBAt)(B矩阵扰动转化为全参数空间的扰动向量);
    • EFlat-LoRA的尖锐度(EMA近似尖锐度):StEMA=L(wt+ε^t−1)−L(wt)S_t^{EMA} = L(w_t + \hat{\varepsilon}_{t-1}) - L(w_t)StEMA=L(wt+ε^t1)L(wt),其中ε^t−1=Vector(E^t−1B⋅At)\hat{\varepsilon}_{t-1} = \text{Vector}(\hat{E}_{t-1}^B \cdot A_t)ε^t1=Vector(E^t1BAt)(EMA扰动转化为全参数空间的扰动向量)。
  • 误差目标:证明∣StEMA−StSAM∣|S_t^{EMA} - S_t^{SAM}|StEMAStSAMttt收敛到0。
2. 推导前提:4个标准假设

原文明确,误差分析需基于4个在SAM类方法中通用的假设(确保推导严谨性):

  1. Lipschitz平滑性(Assumption 1):损失函数L(w)L(w)L(w)关于参数www满足TTT-Lipschitz平滑,即对任意参数w,vw, vw,v
    ∥∇L(w)−∇L(v)∥≤T⋅∥w−v∥ \|\nabla L(w) - \nabla L(v)\| \leq T \cdot \|w - v\|∥∇L(w)L(v)Twv
    TTT为Lipschitz常数,衡量梯度变化的平缓程度);
  2. 梯度有界(Assumption 2):每个mini-batch的梯度范数存在上界GGG,即:
    E[∥∇L(w)∥]≤G\mathbb{E}\left[\|\nabla L(w)\|\right] \leq GE[∥∇L(w)]G
    E[⋅]\mathbb{E}[\cdot]E[]为期望,避免单个batch的异常梯度影响);
  3. 随机梯度方差有界(Assumption 3):设训练集为DDD,mini-batch为B⊆DB \subseteq DBD,随机梯度∇LB(w)\nabla L_B(w)LB(w)与全数据集梯度∇LD(w)\nabla L_D(w)LD(w)的方差存在上界σ2\sigma^2σ2
    E[∥∇LB(w)−∇LD(w)∥2]≤σ2 \mathbb{E}\left[\|\nabla L_B(w) - \nabla L_D(w)\|^2\right] \leq \sigma^2E[∥∇LB(w)LD(w)2]σ2
    (控制mini-batch采样带来的梯度噪声);
  4. 局部凸性(Assumption 4):微调阶段,模型参数接近局部极小值,损失函数在局部邻域内凸且二次可微,即对任意x,yx, yx,y
    L(y)≥L(x)+∇L(x)⊤(y−x)L(y) \geq L(x) + \nabla L(x)^\top (y - x)L(y)L(x)+L(x)(yx)
    (凸性保证误差分析可通过不等式链推导,无局部震荡)。
3. 定理2:EMA扰动的误差界证明

定理2表述:若满足上述4个假设,且扰动半径随迭代衰减ρt=ρ0t\rho_t = \frac{\rho_0}{\sqrt{t}}ρt=tρ0ρ0\rho_0ρ0为初始半径),则ttt步时EMA尖锐度与SAM尖锐度的误差满足:
∣StEMA−StSAM∣≤(T⋅ρ0t−1+G+σ2)⋅(ρ0t+ρ0(1−β)t−1+ρ0) \left| S_t^{EMA} - S_t^{SAM} \right| \leq \left( T \cdot \frac{\rho_0}{\sqrt{t-1}} + G + \sigma^2 \right) \cdot \left( \frac{\rho_0}{\sqrt{t}} + \rho_0 (1-\beta)^{t-1} + \rho_0 \right) StEMAStSAM(Tt1ρ0+G+σ2)(tρ0+ρ0(1β)t1+ρ0)

推导步骤

  1. 用Lipschitz平滑性展开损失差
    StEMA=L(wt+ε^t−1)−L(wt)S_t^{EMA} = L(w_t + \hat{\varepsilon}_{t-1}) - L(w_t)StEMA=L(wt+ε^t1)L(wt)StSAM=L(wt+ε~t)−L(wt)S_t^{SAM} = L(w_t + \tilde{\varepsilon}_t) - L(w_t)StSAM=L(wt+ε~t)L(wt),由Lipschitz平滑性的推论(损失差≤梯度范数×参数差):
    ∣L(wt+ε^t−1)−L(wt+ε~t)∣≤sup⁡w∈[wt+ε^t−1,wt+ε~t]∥∇L(w)∥⋅∥ε^t−1−ε~t∥ \left| L(w_t + \hat{\varepsilon}_{t-1}) - L(w_t + \tilde{\varepsilon}_t) \right| \leq \sup_{w \in [w_t+\hat{\varepsilon}_{t-1}, w_t+\tilde{\varepsilon}_t]} \|\nabla L(w)\| \cdot \|\hat{\varepsilon}_{t-1} - \tilde{\varepsilon}_t\|L(wt+ε^t1)L(wt+ε~t)w[wt+ε^t1,wt+ε~t]sup∥∇L(w)ε^t1ε~t
    其中[a,b][a,b][a,b]表示aaabbb之间的参数区间,sup⁡\supsup为上确界。

  2. 控制梯度范数的上界
    由Assumption 2(梯度有界E[∥∇L(w)∥]≤G\mathbb{E}[\|\nabla L(w)\|] \leq GE[∥∇L(w)]G)和Assumption 3(方差有界σ2\sigma^2σ2),结合Cauchy-Schwarz不等式,可推出:
    sup⁡w∈[wt+ε^t−1,wt+ε~t]∥∇L(w)∥≤∥∇L(wt)∥+T⋅max⁡(∥ε^t−1∥,∥ε~t∥)+σ2 \sup_{w \in [w_t+\hat{\varepsilon}_{t-1}, w_t+\tilde{\varepsilon}_t]} \|\nabla L(w)\| \leq \|\nabla L(w_t)\| + T \cdot \max\left(\|\hat{\varepsilon}_{t-1}\|, \|\tilde{\varepsilon}_t\|\right) + \sigma^2 w[wt+ε^t1,wt+ε~t]sup∥∇L(w)∥∇L(wt)+Tmax(ε^t1,ε~t)+σ2
    再代入∥∇L(wt)∥≤G\|\nabla L(w_t)\| \leq G∥∇L(wt)G(Assumption 2)和max⁡(∥ε^t−1∥,∥ε~t∥)≤ρt−1=ρ0t−1\max\left(\|\hat{\varepsilon}_{t-1}\|, \|\tilde{\varepsilon}_t\|\right) \leq \rho_{t-1} = \frac{\rho_0}{\sqrt{t-1}}max(ε^t1,ε~t)ρt1=t1ρ0(扰动半径衰减),得:
    sup⁡∥∇L(w)∥≤T⋅ρ0t−1+G+σ2 \sup \|\nabla L(w)\| \leq T \cdot \frac{\rho_0}{\sqrt{t-1}} + G + \sigma^2sup∥∇L(w)Tt1ρ0+G+σ2

  3. 控制参数差∥ε^t−1−ε~t∥\|\hat{\varepsilon}_{t-1} - \tilde{\varepsilon}_t\|ε^t1ε~t的上界
    由EMA定义E^t−1B=(1−β)E^t−2B+βEt−1B\hat{E}_{t-1}^B = (1-\beta)\hat{E}_{t-2}^B + \beta E_{t-1}^BE^t1B=(1β)E^t2B+βEt1B,递推可得E^t−1B\hat{E}_{t-1}^BE^t1B是历史E1B,...,Et−1BE_1^B, ..., E_{t-1}^BE1B,...,Et1B的加权和:
    E^t−1B=β∑k=1t−1(1−β)t−1−kEkB\hat{E}_{t-1}^B = \beta \sum_{k=1}^{t-1} (1-\beta)^{t-1-k} E_k^BE^t1B=βk=1t1(1β)t1kEkB
    因此,ε^t−1=Vector(E^t−1BAt)\hat{\varepsilon}_{t-1} = \text{Vector}(\hat{E}_{t-1}^B A_t)ε^t1=Vector(E^t1BAt)ε~t=Vector(EtBAt)\tilde{\varepsilon}_t = \text{Vector}(E_t^B A_t)ε~t=Vector(EtBAt)的差的范数满足:
    ∥ε^t−1−ε~t∥≤∥ε^t−1∥+∥ε~t∥ \|\hat{\varepsilon}_{t-1} - \tilde{\varepsilon}_t\| \leq \|\hat{\varepsilon}_{t-1}\| + \|\tilde{\varepsilon}_t\|ε^t1ε~tε^t1+ε~t
    (三角不等式)。
    其中:

    • ∥ε~t∥≤ρt=ρ0t\|\tilde{\varepsilon}_t\| \leq \rho_t = \frac{\rho_0}{\sqrt{t}}ε~tρt=tρ0(当前扰动半径);
    • ∥ε^t−1∥≤β∑k=1t−1(1−β)t−1−k⋅∥εk∥≤βρ0∑k=1t−1(1−β)t−1−k=ρ0(1−(1−β)t−1)\|\hat{\varepsilon}_{t-1}\| \leq \beta \sum_{k=1}^{t-1} (1-\beta)^{t-1-k} \cdot \|\varepsilon_k\| \leq \beta \rho_0 \sum_{k=1}^{t-1} (1-\beta)^{t-1-k} = \rho_0 (1 - (1-\beta)^{t-1})ε^t1βk=1t1(1β)t1kεkβρ0k=1t1(1β)t1k=ρ0(1(1β)t1)(等比数列求和:∑k=0n(1−β)k=1−(1−β)n+1β\sum_{k=0}^{n} (1-\beta)^k = \frac{1 - (1-\beta)^{n+1}}{\beta}k=0n(1β)k=β1(1β)n+1);
      代入得:
      ∥ε^t−1−ε~t∥≤ρ0(1−(1−β)t−1)+ρ0t≤ρ0t+ρ0(1−β)t−1+ρ0\|\hat{\varepsilon}_{t-1} - \tilde{\varepsilon}_t\| \leq \rho_0 (1 - (1-\beta)^{t-1}) + \frac{\rho_0}{\sqrt{t}} \leq \frac{\rho_0}{\sqrt{t}} + \rho_0 (1-\beta)^{t-1} + \rho_0ε^t1ε~tρ0(1(1β)t1)+tρ0tρ0+ρ0(1β)t1+ρ0
      (因1−(1−β)t−1≤1+ρ0(1−β)t−11 - (1-\beta)^{t-1} \leq 1 + \rho_0 (1-\beta)^{t-1}1(1β)t11+ρ0(1β)t1,简化后不影响收敛性)。
  4. 合并误差界
    将步骤2和步骤3的结果代入步骤1的不等式,最终得到:
    ∣StEMA−StSAM∣≤(T⋅ρ0t−1+G+σ2)⋅(ρ0t+ρ0(1−β)t−1+ρ0) \left| S_t^{EMA} - S_t^{SAM} \right| \leq \left( T \cdot \frac{\rho_0}{\sqrt{t-1}} + G + \sigma^2 \right) \cdot \left( \frac{\rho_0}{\sqrt{t}} + \rho_0 (1-\beta)^{t-1} + \rho_0 \right) StEMAStSAM(Tt1ρ0+G+σ2)(tρ0+ρ0(1β)t1+ρ0)

关键结论:当迭代次数t→+∞t \to +\inftyt+时,ρ0t→0\frac{\rho_0}{\sqrt{t}} \to 0tρ00(1−β)t−1→0(1-\beta)^{t-1} \to 0(1β)t10β∈(0,1)\beta \in (0,1)β(0,1)),因此误差∣StEMA−StSAM∣→0\left| S_t^{EMA} - S_t^{SAM} \right| \to 0StEMAStSAM0,即EMA扰动能无限逼近Flat-LoRA的实时扰动

七、核心推导3:EFlat-LoRA的计算复杂度验证

EFlat-LoRA的“效率”需通过参数复杂度、内存复杂度、时间复杂度的数学推导验证,证明其与LoRA效率相当。

1. 参数复杂度(训练参数数量)

EFlat-LoRA未新增可训练参数,仅复用LoRA的AAAr×mr \times mr×m)和BBBn×rn \times rn×r)矩阵,因此参数数量与LoRA、Flat-LoRA完全一致:
PEFlat-LoRA=PLoRA=PFlat-LoRA=O(nr+rm)≪O(nm)P_{\text{EFlat-LoRA}} = P_{\text{LoRA}} = P_{\text{Flat-LoRA}} = O(nr + rm) \ll O(nm) PEFlat-LoRA=PLoRA=PFlat-LoRA=O(nr+rm)O(nm)
其中n×mn \times mn×m为原模型权重矩阵维度,r≪min⁡(n,m)r \ll \min(n,m)rmin(n,m)(低秩约束),保证参数效率。

2. 内存复杂度(存储开销)

EFlat-LoRA的额外内存开销仅来自“EMA扰动E^tB\hat{E}_t^BE^tB的存储”(n×rn \times rn×r矩阵,因仅对B加扰动),需对比Flat-LoRA:

  • Flat-LoRA需存储“AAA的梯度+原A/BA/BA/B备份”,内存为MLoRA+O(1.5(nr+rm))M_{\text{LoRA}} + O(1.5(nr + rm))MLoRA+O(1.5(nr+rm))
  • EFlat-LoRA需存储“EMA扰动E^tB\hat{E}_t^BE^tB”(O(nr)O(nr)O(nr)),且现代优化器(如AdamW)已需存储“动量+二阶动量”(O(2(nr+rm))O(2(nr + rm))O(2(nr+rm))),因此EFlat-LoRA的内存为:
    MEFlat-LoRA=MLoRA+O(2(nr+rm)) M_{\text{EFlat-LoRA}} = M_{\text{LoRA}} + O(2(nr + rm)) MEFlat-LoRA=MLoRA+O(2(nr+rm))
3. 时间复杂度(训练耗时)

时间复杂度的核心是“前向-反向传播次数”:

  • LoRA:每次迭代1次前向+1次反向,时间复杂度O(T)O(T)O(T)TTT为单轮前向-反向时间);
  • Flat-LoRA:每次迭代需2次前向+2次反向(原参数点+扰动参数点),时间复杂度O(2T)O(2T)O(2T)
  • EFlat-LoRA:因用EMA扰动E^t−1B\hat{E}_{t-1}^BE^t1B近似实时扰动EtBE_t^BEtB,仅需在“更新EMA时计算1次EtBE_t^BEtB”,每次迭代仅1次前向+1次反向,时间复杂度:
    TEFlat-LoRA=O(T)=TLoRA T_{\text{EFlat-LoRA}} = O(T) = T_{\text{LoRA}} TEFlat-LoRA=O(T)=TLoRA
    (与LoRA效率完全一致,解决Flat-LoRA的耗时问题)。
http://www.dtcms.com/a/438893.html

相关文章:

  • 【面板数据】全国分省教育支出水平数据集(2007-2023年)
  • 做网站工商局要不要备案呢wordpress主题多页面
  • 济南网站建设服务公司wordpress 远程插件安装 ftp
  • 济南网站建设公司按需定制保险网站建设的目标
  • 深入剖析:boost::intrusive_ptr 与 std::shared_ptr 的性能边界和实现哲学
  • 聊城制作手机网站公司网站建设需要的条件
  • SQL 子查询与多表 JOIN 用法大全(速查版)
  • Leetcode 239. 滑动窗口最大值 优先队列 / 双向单调队列
  • Nacos 工作原理及流量走向
  • 夏津网站建设茂名企业建站程序
  • OSPF 单区域实验 概念及题目
  • 建立一个门户网站WordPress域名后问号英文
  • 自上而下VS自下而上:设计哲学全解
  • 【开题答辩全过程】以 SpringCloud家乡美旅行交流博客平台为例,包含答辩的问题和答案
  • 2015优先中文公司官网wordpress模板
  • 国外优秀企业网站网络空间的竞争归根结底是
  • 哪些外贸网站可以做soho求网站2021给个网址
  • 2022年网站能用的兰州企业网站制作
  • 网页设计与网站建设实战大全推荐好的网站或网页
  • 查看网站是否做百度推广如果在网上接网站建设项目
  • 如何用源码搭建网站源码网站搭建规划
  • 【办公类-117-01】20250924通义万相视频2.5——三个小人(幼儿作品动态化)
  • PBS, 以太坊的棘刺雕猴
  • 【未来】智能体互联时代的商业模式变化和挑战:从HOM到AOM
  • 域名免费注册网站网站模板凡建站
  • 关键词挖掘站长c 教程如何做网站
  • 爬坑 10 年总结!淘宝全量商品接口实战开发:从分页优化到数据完整性闭环
  • 网站的设计制作流程网络营销的流程
  • 网站改版计划珠宝 网站模板
  • LangChain源码分析(九)- 向量存储