Bi-LoRA的数学推导
Bi-LoRA的数学推导
1. 模型结构定义
设预训练模型的某层权重为W0∈Rd×k\mathbf{W}_0\in\mathbb{R}^{d\times k}W0∈Rd×k(ddd为输出维度,kkk为输入维度)。Bi-LoRA引入两个低秩模块对W0\mathbf{W}_0W0进行扰动:
- 主模块:ΔW1=B1A1\Delta\mathbf{W}_1 = \mathbf{B}_1\mathbf{A}_1ΔW1=B1A1,其中A1∈Rr×k\mathbf{A}_1\in\mathbb{R}^{r\times k}A1∈Rr×k,B1∈Rd×r\mathbf{B}_1\in\mathbb{R}^{d\times r}B1∈Rd×r(r≪min(d,k)r\ll\min(d,k)r≪min(d,k)为秩),用于任务适配。
- 辅助模块:ΔW2=B2A2\Delta\mathbf{W}_2 = \mathbf{B}_2\mathbf{A}_2ΔW2=B2A2,其中A2∈Rr×k\mathbf{A}_2\in\mathbb{R}^{r\times k}A2∈Rr×k,B2∈Rd×r\mathbf{B}_2\in\mathbb{R}^{d\times r}B2∈Rd×r,用于模拟对抗性扰动。
模型最终输出为:
y=(W0+ΔW1)x\mathbf{y} = (\mathbf{W}_0 + \Delta\mathbf{W}_1)\mathbf{x}y=(W0+ΔW1)x
(推理时移除辅助模块,仅保留主模块与预训练权重)
2. 损失函数构造
Bi-LoRA的核心是通过辅助模块模拟SAM的“对抗性扰动”,同时优化主模块的任务适配性。定义损失函数为:
LBi-LoRA=L(y∗,y1)+λ⋅L(y∗,y1+Δy2)\mathcal{L}_{\text{Bi-LoRA}} = \mathcal{L}(\mathbf{y}^*,\mathbf{y}_1) + \lambda\cdot\mathcal{L}(\mathbf{y}^*,\mathbf{y}_1+\Delta\mathbf{y}_2)LBi-LoRA=L(y∗,y1)+λ⋅L(y∗,y1+Δy2)
其中:
- y∗\mathbf{y}^*y∗为真实标签,L\mathcal{L}L为任务损失(如交叉熵)。
- y1=(W0+ΔW1)x\mathbf{y}_1 = (\mathbf{W}_0 + \Delta\mathbf{W}_1)\mathbf{x}y1=(W0+ΔW1)x为主模块输出。
- Δy2=ΔW2x=B2A2x\Delta\mathbf{y}_2 = \Delta\mathbf{W}_2\mathbf{x} = \mathbf{B}_2\mathbf{A}_2\mathbf{x}Δy2=ΔW2x=B2A2x为辅助模块引入的扰动输出。
- λ>0\lambda>0λ>0为平衡系数,控制扰动影响权重。
3. 参数更新规则
参数更新分主模块(A1,B1\mathbf{A}_1,\mathbf{B}_1A1,B1)和辅助模块(A2,B2\mathbf{A}_2,\mathbf{B}_2A2,B2)两部分,均基于梯度下降/上升:
3.1 主模块更新(梯度下降,最小化损失)
主模块参数通过LBi-LoRA\mathcal{L}_{\text{Bi-LoRA}}LBi-LoRA对A1,B1\mathbf{A}_1,\mathbf{B}_1A1,B1的梯度更新:
A1←A1−η⋅∇A1LBi-LoRA\mathbf{A}_1 \leftarrow \mathbf{A}_1 - \eta\cdot\nabla_{\mathbf{A}_1}\mathcal{L}_{\text{Bi-LoRA}}A1←A1−η⋅∇A1LBi-LoRA
B1←B1−η⋅∇B1LBi-LoRA\mathbf{B}_1 \leftarrow \mathbf{B}_1 - \eta\cdot\nabla_{\mathbf{B}_1}\mathcal{L}_{\text{Bi-LoRA}}B1←B1−η⋅∇B1LBi-LoRA
其中η\etaη为学习率,梯度计算为:
∇A1LBi-LoRA=∇A1L(y∗,y1)+λ⋅∇A1L(y∗,y1+Δy2)\nabla_{\mathbf{A}_1}\mathcal{L}_{\text{Bi-LoRA}} = \nabla_{\mathbf{A}_1}\mathcal{L}(\mathbf{y}^*,\mathbf{y}_1) + \lambda\cdot\nabla_{\mathbf{A}_1}\mathcal{L}(\mathbf{y}^*,\mathbf{y}_1+\Delta\mathbf{y}_2)∇A1LBi-LoRA=∇A1L(y∗,y1)+λ⋅∇A1L(y∗,y1+Δy2)
∇B1LBi-LoRA=∇B1L(y∗,y1)+λ⋅∇B1L(y∗,y1+Δy2)\nabla_{\mathbf{B}_1}\mathcal{L}_{\text{Bi-LoRA}} = \nabla_{\mathbf{B}_1}\mathcal{L}(\mathbf{y}^*,\mathbf{y}_1) + \lambda\cdot\nabla_{\mathbf{B}_1}\mathcal{L}(\mathbf{y}^*,\mathbf{y}_1+\Delta\mathbf{y}_2)∇B1LBi-LoRA=∇B1L(y∗,y1)+λ⋅∇B1L(y∗,y1+Δy2)
3.2 辅助模块更新(梯度上升,最大化扰动损失)
辅助模块需模拟“使损失增大的扰动”,故通过负梯度更新(等价于梯度上升):
A2←A2+η⋅∇A2L(y∗,y1+Δy2)\mathbf{A}_2 \leftarrow \mathbf{A}_2 + \eta\cdot\nabla_{\mathbf{A}_2}\mathcal{L}(\mathbf{y}^*,\mathbf{y}_1+\Delta\mathbf{y}_2)A2←A2+η⋅∇A2L(y∗,y1+Δy2)
B2←B2+η⋅∇B2L(y∗,y1+Δy2)\mathbf{B}_2 \leftarrow \mathbf{B}_2 + \eta\cdot\nabla_{\mathbf{B}_2}\mathcal{L}(\mathbf{y}^*,\mathbf{y}_1+\Delta\mathbf{y}_2)B2←B2+η⋅∇B2L(y∗,y1+Δy2)
(仅依赖扰动损失项,目标是增强对不稳定参数区域的探测)
4. 辅助模块的范数约束(扰动控制)
为限制辅助模块的扰动幅度,每次更新后需对其进行归一化。设模型有NNN个LoRA层,定义第jjj层辅助模块的弗罗贝尼乌斯范数为∥ΔW2(j)∥F=∥B2(j)A2(j)∥F\|\Delta\mathbf{W}_2^{(j)}\|_F = \|\mathbf{B}_2^{(j)}\mathbf{A}_2^{(j)}\|_F∥ΔW2(j)∥F=∥B2(j)A2(j)∥F,总范数为:
cnorm=∑j=1N∥ΔW2(j)∥F2c_{\text{norm}} = \sqrt{\sum_{j=1}^N\|\Delta\mathbf{W}_2^{(j)}\|_F^2}cnorm=j=1∑N∥ΔW2(j)∥F2
若cnorm>ρc_{\text{norm}} > \rhocnorm>ρ(ρ\rhoρ为预设邻域半径),则对所有层的A2(j),B2(j)\mathbf{A}_2^{(j)},\mathbf{B}_2^{(j)}A2(j),B2(j)进行缩放:
B2(j)←ρ/cnorm⋅B2(j)\mathbf{B}_2^{(j)} \leftarrow \sqrt{\rho/c_{\text{norm}}}\cdot\mathbf{B}_2^{(j)}B2(j)←ρ/cnorm⋅B2(j)
A2(j)←ρ/cnorm⋅A2(j)\mathbf{A}_2^{(j)} \leftarrow \sqrt{\rho/c_{\text{norm}}}\cdot\mathbf{A}_2^{(j)}A2(j)←ρ/cnorm⋅A2(j)
缩放后满足∑j=1N∥ΔW2(j)∥F2=ρ\sum_{j=1}^N\|\Delta\mathbf{W}_2^{(j)}\|_F^2 = \rho∑j=1N∥ΔW2(j)∥F2=ρ,确保扰动被约束在ρ\rhoρ-范数球内。
5. 收敛性保证
通过主模块与辅助模块的协同优化,Bi-LoRA的损失函数满足:
LBi-LoRA≥L(y∗,y1)−λ⋅ρ⋅∥∇WL∥max\mathcal{L}_{\text{Bi-LoRA}} \geq \mathcal{L}(\mathbf{y}^*,\mathbf{y}_1) - \lambda\cdot\rho\cdot\|\nabla_{\mathbf{W}}\mathcal{L}\|_{\text{max}}LBi-LoRA≥L(y∗,y1)−λ⋅ρ⋅∥∇WL∥max
其中∥∇WL∥max\|\nabla_{\mathbf{W}}\mathcal{L}\|_{\text{max}}∥∇WL∥max为权重梯度的最大模。该式表明,辅助模块的扰动被严格控制,主模块能在“平坦区域”收敛,最终提升模型泛化性。