当前位置: 首页 > news >正文

大模型微调之LORA核心逻辑

目录

  • 1. 预备知识
    • 1.1 低秩矩阵分解
  • 2. 模拟举例

1. 预备知识

1.1 低秩矩阵分解

低秩矩阵分解是一种将高维矩阵近似为两个低维矩阵乘积的技术,常用于数据降维、压缩、推荐系统等领域。

步骤1:理解目标
我们有一个高维矩阵 ΔW\Delta WΔW ,希望将其近似为两个低维矩阵AAABBB的乘积,即ΔW≈BA\Delta W \approx BAΔWBA

步骤2:设定矩阵维度
假设 ΔW\Delta WΔW是一个d×dd \times dd×d的矩阵。我们选择一个较小的整数 rrr,使得 r≪dr \ll drd 。矩阵 AAA 的维度将是 d×rd \times rd×r ,矩阵 BBB 的维度将是 r×dr \times dr×d

步骤3:矩阵初始化

  • 初始化矩阵 AAABBB 。可以使用随机初始化、正态分布初始化等方法。例如:
    • A∼N(0,σ2)A \sim \mathcal{N}(0, \sigma^2)AN(0,σ2) ,表示矩阵 AAA 的每个元素都是从均值为0、方差为 σ2\sigma^2σ2 的正态分布中随机抽取的。
    • BBB 初始化为零矩阵,即 B=0B = 0B=0

步骤4:矩阵乘积

  • 通过矩阵乘积 BABABA ,可以得到一个近似的 d×dd \times dd×d 矩阵:

    W’ = BA其中 W′≈ΔWW' \approx \Delta WWΔW

步骤5:优化和训练

  • 在训练过程中,通过优化算法(如梯度下降),不断调整矩阵 AAABBB 的值,使得 W′W'W 更加接近于 ΔW\Delta WΔW
  • 损失函数通常是衡量 ΔW\Delta WΔWW′W'W 之间差距的一个函数,例如均方误差:
    在这里插入图片描述

步骤6:更新规则

  • 通过优化算法计算损失函数关于 AAABBB 的梯度,并更新 AAABBB 的值。例如,使用梯度下降法更新规则如下:

在这里插入图片描述

其中 η\etaη 是学习率。

2. 模拟举例

import numpy as np# 初始化矩阵 W
W = np.array([[4, 3, 2, 1],[2, 2, 2, 2],[1, 3, 4, 2],[0, 1, 2, 3]])
# 矩阵维度
d = W.shape[0] # 4# 秩
r = 2# 随机初始化 A 和 B
np.random.seed(666)# A 和 B 的元素服从标准正态分布
A = np.random.randn(d, r)
B = np.zeros((r, d))
array([[ 0.82418808,  0.479966  ],[ 1.17346801,  0.90904807],[-0.57172145, -0.10949727],[ 0.01902826, -0.94376106]])
array([[0., 0., 0., 0.],[0., 0., 0., 0.]])
# 定义超参数lr = 0.01 # 学习率,用于控制梯度下降的步长。epochs = 1000 # 迭代次数,进行多少次梯度下降更新。
# 定义损失函数def loss_function(W, A, B):'''W:目标矩阵A:矩阵分解中的一个矩阵,通常是随机初始化的。B:矩阵分解中的另一个矩阵,通常是零矩阵初始化的。'''# 矩阵相乘,@是Python中的矩阵乘法运算符,相当于np.matmul(A, B)。W_approx = A @ B# 损失函数越小,表示 A 和 B 的乘积 W_approx越接近于目标矩阵 Wreturn np.linalg.norm(W - W_approx, "fro")**2 #使用均方误差# 定义梯度下降更新
def descent(W, A, B, lr, epochs):'''梯度下降法'''# 用于记录损失值loss_history = []for i in range(epochs):# 计算梯度W_approx = A @ B# 计算损失函数关于矩阵A的梯度gd_A = -2 * (W - W_approx) @ B.T# 计算损失函数关于矩阵B的梯度gd_B = -2 * A.T @ ( W - W_approx)# 使用梯度下降更新矩阵A和BA -= lr * gd_AB -= lr * gd_B# 计算当前损失loss = loss_function(W, A, B)loss_history.append(loss)# 每100个epoch打印一次if i % 100 == 0:print(f"Epoch: {i} , 损失: {loss:.4f}")# 返回优化后的矩阵return A, B, loss_history# 进行梯度下降优化A, B, loss_history = descent(W, A, B, lr, epochs)
Epoch: 0 , 损失: 87.6534
Epoch: 100 , 损失: 2.3620
Epoch: 200 , 损失: 2.3566
Epoch: 300 , 损失: 2.3566
Epoch: 400 , 损失: 2.3566
Epoch: 500 , 损失: 2.3566
Epoch: 600 , 损失: 2.3566
Epoch: 700 , 损失: 2.3566
Epoch: 800 , 损失: 2.3566
Epoch: 900 , 损失: 2.3566

可以看到loss一直在下降,但是由于这个例子比较简单,loss到2.3620就不再下降了

# 绘制损失曲线
import matplotlib.pyplot as pltplt.figure(figsize=(8,6))
plt.plot(loss_history, label="loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Loss Curve")
plt.legend()
plt.grid(True)
plt.show()

在这里插入图片描述
从损失曲线也可以看出,loss是一直下降的,到达最低点后保持不变了

# 最终的近似矩阵
W_approx = A @ Bprint(W_approx)
[[ 3.92499196  3.06584542  2.10616302  0.84487308][ 1.80749375  2.16899061  2.27246474  1.60187065][ 1.39233235  2.65559308  3.44471033  2.81139716][-0.31000446  1.27213581  2.43876645  2.35886822]]
# 原始的矩阵 Wprint(W)
[[4 3 2 1][2 2 2 2][1 3 4 2][0 1 2 3]]

可以看到近似矩阵和原始矩阵值十分接近,所以说这个微调效果还是比较好的


文章转载自:

http://2eCI5mS9.pxLqL.cn
http://0DRxe1jl.pxLqL.cn
http://xVSqdHBh.pxLqL.cn
http://NUYqj3KA.pxLqL.cn
http://SWOwLPIs.pxLqL.cn
http://atSg3lA3.pxLqL.cn
http://VPUniQK8.pxLqL.cn
http://0eCwemvr.pxLqL.cn
http://tpptl62J.pxLqL.cn
http://QwLsfOLs.pxLqL.cn
http://KKF2eTW3.pxLqL.cn
http://9knq5x8f.pxLqL.cn
http://yODwGcg7.pxLqL.cn
http://2HNEoxJX.pxLqL.cn
http://Y4VDe8Eg.pxLqL.cn
http://gTGhcaoB.pxLqL.cn
http://CMOGpWa3.pxLqL.cn
http://F5AtRqd7.pxLqL.cn
http://wNldHXxK.pxLqL.cn
http://8tbI8JvW.pxLqL.cn
http://N5MT5Ots.pxLqL.cn
http://0a0pGUGm.pxLqL.cn
http://r9hedJnu.pxLqL.cn
http://uPd7vzZ3.pxLqL.cn
http://CjWMKCGr.pxLqL.cn
http://EtFR90e2.pxLqL.cn
http://hVFtV70s.pxLqL.cn
http://QHltw4pM.pxLqL.cn
http://yhz82kZG.pxLqL.cn
http://FznKWhUm.pxLqL.cn
http://www.dtcms.com/a/366967.html

相关文章:

  • React笔记_组件之间进行数据传递
  • 《Java餐厅的待客之道:BIO, NIO, AIO三种服务模式的进化》
  • 【OpenHarmony文件管理子系统】文件访问接口解析
  • sealos部署k8s
  • (C题|NIPT 的时点选择与胎儿的异常判定)2025年高教杯全国大学生数学建模国赛解题思路|完整代码论文集合
  • 25高教社杯数模国赛【C题国一学长思路+问题分析】第二弹
  • 数学建模25c
  • 互联网大厂Java面试场景与问题解答
  • LeetCode 刷题【64. 最小路径和】
  • Rust+slint实现一个登录demo
  • Rust 文件操作终极实战指南:从基础读写到进阶锁控,一文搞定所有 IO 场景
  • 代码随想录算法训练营第二十八天 | 买卖股票的最佳实际、跳跃游戏、K次取反后最大化的数组和
  • 2025全国大学生数学建模C题保姆级思路模型(持续更新):NIPT 的时点选择与胎儿的异常判定
  • 2025反爬虫之战札记:从robots.txt到多层防御的攻防进化史
  • 23种设计模式——工厂方法模式(Factory Method Pattern)详解
  • C++ 学习与 CLion 使用:(七)if 逻辑判断和 switch 语句
  • docker中的mysql变更宿主机映射端口
  • Redis(43)Redis哨兵(Sentinel)是什么?
  • 【连载 7/9】大模型应用:大模型应用:(七)大模型使用工具(29页)【附全文阅读】
  • 从 GPT 到 LLaMA:解密 LLM 的核心架构——Decoder-Only 模型
  • 原型链和原型
  • 嵌入式学习 51单片机(3)
  • 详细学习计划
  • 深度解读《实施“人工智能+”行动的意见》:一场由场景、数据与价值链共同定义的产业升级
  • CLIP模型
  • 深度学习篇---SENet网络结构
  • JS初入门
  • 大数据开发计划表(实际版)
  • TypeScript 增强功能大纲 (相对于 ECMAScript)
  • LLAMAFACTORY:一键优化大型语言模型微调的利器