当前位置: 首页 > news >正文

最小二乘求解器lstsq,处理带权重和L2正则的线性回归

目录

代码注释版:

关键功能说明:

torch.linalg.cholesky 的原理

代码示例

Cholesky 分解的应用

与 torch.cholesky 的区别

总结


代码注释版:

from typing import Optional

import torch


def lstsq(
    matrix: torch.Tensor, 
    rhs: torch.Tensor, 
    weights: torch.Tensor, 
    l2_regularizer: Optional[torch.Tensor] = None,
    l2_regularizer_rhs: Optional[torch.Tensor] = None,
    shared: bool = False
) -> torch.Tensor:
    """带权重和L2正则化的最小二乘求解器,使用Cholesky分解
    
    解决形如 (A^T W A + λI) x = A^T W b 的线性系统
    支持多任务共享参数(通过shared参数合并Gram矩阵和右侧项)
    
    Args:
        matrix: 设计矩阵A,形状为 [batch_size, n_obs, n_params]
        rhs: 右侧项b,形状为 [batch_size, n_obs, n_outputs]
        weights: 权重矩阵W的对角元素,形状为 [batch_size, n_obs]
        l2_regularizer: L2正则化项λ的对角矩阵,形状为 [batch_size, n_params, n_params]
        l2_regularizer_rhs: 正则化项对右侧的修正,形状为 [batch_size, n_params, n_outputs]
        shared: 是否共享参数(将多个系统的Gram矩阵和右侧项求和)
    
    Returns:
        最小二乘解,形状为 [batch_size, n_params, n_outputs]
    """
    # 加权设计矩阵: W^(1/2) * A
    weighted_matrix = weights.unsqueeze(-1) * matrix
    
    # 计算正则化的Gram矩阵: A^T W A + λI
    regularized_gramian = weighted_matrix.mT @ matrix
    if l2_regularizer is not None:
        regularized_gramian += l2_regularizer  # 添加L2正则项
    
    # 计算右侧项: A^T W b + λ_rhs
    ATb = weighted_matrix.mT @ rhs
    if l2_regularizer_rhs is not None:
        ATb += l2_regularizer_rhs
    
    # 如果共享参数,合并所有batch的贡献
    if shared:
        regularized_gramian = regularized_gramian.sum(dim=0, keepdim=True)
        ATb = ATb.sum(dim=0, keepdim=True)
    
    # Cholesky分解求解
    chol = torch.linalg.cholesky(regularized_gramian)
    return torch.cholesky_solve(ATb, chol)


def lstsq_partial_share(
    matrix: torch.Tensor,
    rhs: torch.Tensor,
    weights: torch.Tensor,
    l2_regularizer: torch.Tensor,
    n_shared: int = 0
) -> torch.Tensor:
    """部分参数共享的最小二乘求解器
    
    将参数分为共享部分和独立部分:
    - 共享参数在所有样本间共享
    - 独立参数每个样本单独估计
    通过分块回归实现高效求解
    
    Args:
        matrix: 设计矩阵A,形状为 [batch_size, n_obs, n_params]
        rhs: 右侧项b,形状为 [batch_size, n_obs, n_outputs]
        weights: 权重矩阵的对角元素,形状为 [batch_size, n_obs]
        l2_regularizer: 正则化强度,形状为 [batch_size, n_params]
        n_shared: 共享参数的数量
    
    Returns:
        参数矩阵,前n_shared列为共享参数,其余为独立参数
        形状为 [batch_size, n_params, n_outputs]
    """
    n_params = matrix.shape[-1]
    n_rhs_outputs = rhs.shape[-1]
    n_indep = n_params - n_shared

    # 全共享情况直接返回广播结果
    if n_indep == 0:
        result = lstsq(matrix, rhs, weights, l2_regularizer, shared=True)
        return result.expand(matrix.shape[0], -1, -1)

    # 将正则化项转换为设计矩阵的扩展部分
    # 相当于添加 λI 的正则化项
    matrix = torch.cat([matrix, batch_eye(n_params, matrix.shape[0])], dim=1)
    rhs = torch.nn.functional.pad(rhs, (0, 0, 0, n_params))  # 右侧添加0
    weights = torch.cat([weights, l2_regularizer.unsqueeze(0).expand(matrix.shape[0], -1)], dim=1)

    # 分割共享和独立参数对应的设计矩阵
    matrix_shared, matrix_indep = torch.split(matrix, [n_shared, n_indep], dim=-1)

    # 步骤1:求解独立参数对共享参数和输出的影响
    indep_coeffs = lstsq(matrix_indep, torch.cat([matrix_shared, rhs], dim=-1), weights)
    coeff_indep2shared, coeff_indep2rhs = torch.split(indep_coeffs, [n_shared, n_rhs_outputs], dim=-1)

    # 步骤2:用残差求解共享参数
    shared_residual = matrix_shared - matrix_indep @ coeff_indep2shared
    rhs_residual = rhs - matrix_indep @ coeff_indep2rhs
    coeff_shared2rhs = lstsq(shared_residual, rhs_residual, weights, shared=True)

    # 步骤3:更新独立参数系数
    coeff_indep2rhs = coeff_indep2rhs - coeff_indep2shared @ coeff_shared2rhs

    # 合并结果:共享参数广播,独立参数保持独立
    coeff_shared2rhs = coeff_shared2rhs.expand(matrix.shape[0], -1, -1)
    return torch.cat([coeff_shared2rhs, coeff_indep2rhs], dim=1)


def batch_eye(n_params: int, batch_size: int) -> torch.Tensor:
    """生成批次对角矩阵
    
    Args:
        n_params: 矩阵维度
        batch_size: 批次大小
    
    Returns:
        形状为 [batch_size, n_params, n_params] 的单位矩阵批次
    """
    return torch.eye(n_params).reshape(1, n_params, n_params).expand(batch_size, -1, -1)

关键功能说明:

  1. lstsq:

    • 核心最小二乘求解器,处理带权重和L2正则的线性回归

    • 使用Cholesky分解提高数值稳定性

    • 支持多任务参数共享模式(shared=True时合并所有任务的贡献)

  2. lstsq_partial_share:

    • 处理部分参数共享的回归问题

    • 通过三步分块回归实现:

      1. 估计独立参数对共享参数和输出的影响

      2. 用残差估计共享参数

      3. 修正独立参数估计值

    • 通过矩阵拼接技巧将正则化转换为设计矩阵扩展

  3. batch_eye:

    • 生成批次单位矩阵,用于构建正则化项

    • 典型应用:将L2正则转换为扩展设计矩阵的伪观测

torch.linalg.cholesky 的原理

torch.linalg.cholesky(A) 用于对对称正定矩阵 AAA 进行 Cholesky 分解,即将其分解为:

A=LLTA = L L^TA=LLT

其中:

  • AAA 是 对称正定矩阵(必须满足 A=ATA = A^TA=AT 且所有特征值大于 0)。

  • LLL 是 下三角矩阵

计算 Cholesky 分解 的方式基于逐行计算 LLL:

  1. 计算对角元素:

    Lii=Aii−∑k=1i−1Lik2L_{ii} = \sqrt{ A_{ii} - \sum_{k=1}^{i-1} L_{ik}^2 }Lii​=Aii​−k=1∑i−1​Lik2​​
  2. 计算非对角元素:

    Lji=1Lii(Aji−∑k=1i−1LjkLik),j>iL_{ji} = \frac{1}{L_{ii}} \left( A_{ji} - \sum_{k=1}^{i-1} L_{jk} L_{ik} \right), \quad j > iLji​=Lii​1​(Aji​−k=1∑i−1​Ljk​Lik​),j>i

这个算法 只需要计算下三角部分,所以比 LU 分解 计算量更少,适用于 正定矩阵的快速求解


代码示例

import torch

# 生成一个对称正定矩阵
A = torch.tensor([[4.0, 12.0, -16.0], 
                  [12.0, 37.0, -43.0], 
                  [-16.0, -43.0, 98.0]])

# Cholesky 分解
L = torch.linalg.cholesky(A)
print(L)

输出

tensor([[ 2.0000, 0.0000, 0.0000],

[ 6.0000, 1.0000, 0.0000],

[-8.0000, 5.0000, 3.0000]])

可以验证:

print(torch.mm(L, L.T))

# 结果应当等于 A

Cholesky 分解的应用

  1. 解线性方程组 Ax=bAx = bAx=b:

    • 先求 L = torch.linalg.cholesky(A)

    • Ly = b(前代法)

    • L^T x = y(后代法)

  2. 生成多元正态分布

    • 如果协方差矩阵 Σ\SigmaΣ 进行 Cholesky 分解 Σ=LLT\Sigma = L L^TΣ=LLT,

    • 则可以用 L @ torch.randn(n, d) 生成符合协方差 Σ\SigmaΣ 的多元正态分布数据。


torch.cholesky 的区别

  • torch.cholesky(A) 旧版 API,不推荐使用。

  • torch.linalg.cholesky(A) 现代 API,支持 batch 计算,推荐使用。


总结

  • torch.linalg.cholesky(A) 计算 对称正定矩阵Cholesky 分解,分解成下三角矩阵 L,使得 A=LLTA = L L^TA=LLT。

  • 计算方式比 LU 分解更快,主要用于 正定矩阵的求解、统计学、多元正态分布 等。

  • 使用 Cholesky 分解求解线性方程组比直接求逆更稳定高效。

http://www.dtcms.com/a/108154.html

相关文章:

  • Vue3 + Element Plus + AntV X6 实现拖拽树组件
  • 【人工智能之大模型】如何缓解大语言模型LLMs重复读的问题?
  • 函数ioctl(Input/Output Control)
  • mac如何将jar包上传到maven中央仓库中
  • LeetCode-695. 岛屿的最大面积
  • Linux系统之systemctl管理服务及编译安装配置文件安装实现systemctl管理服务
  • Redis-10.在Java中操作Redis-Spring Data Redis使用方式-操作步骤说明
  • 基于随机森林算法的信用风险评估项目
  • 汇编学习结语
  • Dify案例-接入飞书云文档实现需求质量评估
  • MongoDB文档操作
  • 基于HTML5的音乐播放器(源码+lw+部署文档+讲解),源码可白嫖!
  • vscode代码片段的设置与使用
  • 填坑日志(20250402)解决Jira Rest API出现403XSRF check failed报错的问题
  • Ansible(4)—— Playbook
  • STL 性能优化实战:解决项目中标准模板库的性能瓶颈
  • C语言跳表(Skip List)算法:数据世界的“时光穿梭机”
  • Node.js v22.14.0 多平台安装指南:Windows、Linux 和 macOS 详细教程
  • 当AI开始“思考“:大语言模型的文字认知三部曲
  • Vue 中 this.$emit(“update:xx“,value) 和 :xx.sync 实现同步数据的做法
  • 创建灵活可配置的轮播图组件: GrapesJS 与 Vue3 的完美结合
  • 超短波通信模拟设备:增强通信能力的关键工具
  • 【3.软件工程】3.2 瀑布模型
  • MySQL 高级查询:JOIN、子查询、窗口函数
  • 3D AI 公司 VAST 开源基础 3D 生成模型 TripoSG 和 TripoSF
  • nocobase + Python爬虫实现数据可视化
  • 超详细!!!一文理解Prompting Depth Anything(CVPR2025)
  • 使用Docker安装及使用最新版本的Jenkins
  • Unity打包webgl本地测试
  • 无人机机体结构设计要点与难点!