当前位置: 首页 > news >正文

机器学习代码基础——ML2 使用梯度下降的线性回归

ML2 使用梯度下降的线性回归

牛客网

描述

编写一个使用梯度下降执行线性回归的 Python 函数。该函数应将 NumPy 数组 X(具有一列截距的特征)和 y(目标)作为输入,以及学习率 alpha 和迭代次数,并返回一个 NumPy 数组,表示线性回归模型的系数。

输入描述:

第1行输入X,第2行输入y,第3行输入alpha,第4行输入迭代次数。

输出描述:

输出线性回归模型的系数,四舍五入到小数点后四位。返回类型是List类型。

输入:
[[1, 1], [1, 2], [1, 3], [1, 4]]
[2, 3, 4, 5]
0.01
1000

输出: 
[0.8678 1.045 ]
import numpy as np
def linear_regression_gradient_descent(X, y, alpha, iterations):
    # 补全代码
    m,n = X.shape
    theta = np.zeros((n,1)) # 为了和答案一致
    for _ in range(iterations):
        y_predict = X@theta
        errors = y_predict - y
        discent = X.T@(errors)/m
        theta = theta - alpha * discent
    return np.round(theta.flatten(), 4)

# 主程序
if __name__ == "__main__":
    # 输入矩阵和向量
    matrix_inputx = input()
    array_y = input()
    alpha = input()
    iterations = input()

    # 处理输入
    import ast
    matrix = np.array(ast.literal_eval(matrix_inputx))
    y = np.array(ast.literal_eval(array_y)).reshape(-1,1)
    alpha = float(alpha)
    iterations = int(iterations)

    # 调用函数计算逆矩阵
    output = linear_regression_gradient_descent(matrix,y,alpha,iterations)
    
    # 输出结果
    print(output)


[0.8678 1.045 ]

梯度下降求解

梯度下降是一种计算局部最小值的一种方法。梯度下降思想就是给定一个初始值𝜃,每次沿着函数梯度下降的方向移动𝜃:

θ ( t + 1 ) : = θ ( t ) − α ∇ θ J ( θ ( t ) ) \theta^{(t+1)} := \theta^{(t)} - \alpha \nabla_{\theta} J(\theta^{(t)}) θ(t+1):=θ(t)αθJ(θ(t))

在梯度为零或趋近于零的时候收敛
J ( θ ) = 1 2 n ∑ i = 1 n ( x i T θ − y i ) 2 J(\theta)=\frac{1}{2n}\sum^n_{i=1}(x_i^T\theta-y_i)^2 J(θ)=2n1i=1n(xiTθyi)2
对损失函数求偏导可得到 (n个样本,每个样本p维)
x i = ( x i , 0 , . . . , x i , p ) T x i j 表示第 i 个样本的第 j 个分量 ∂ θ j 1 2 n ( x i T θ − y i ) 2 = ∂ θ j 1 2 n ( ∑ j = 0 p x i , j θ j − y i ) 2 = 1 n ( ∑ j = 0 p x i , j θ j − y i ) x i , j = 1 n ( f ( x i ) − y i ) ) x i , j ∇ θ J = [ J θ 0 J θ 1 . . . J θ p ] x_i=(x_{i,0},...,x_{i,p})^T\\ x_{ij}表示第i个样本的第j个分量\\ \frac{\partial}{\theta_j}\frac{1}{2n}(x_i^T\theta-y_i)^2=\frac{\partial}{\theta_j}\frac{1}{2n}(\sum^p_{j=0}x_{i,j}\theta_j-y_i)^2=\frac{1}{n}(\sum^p_{j=0}x_{i,j}\theta_j-y_i)x_{i,j}=\frac{1}{n}(f(x_i)-y_i))x_{i,j} \\ \nabla_\theta J=\begin{bmatrix} \frac{J}{\theta_0}\\ \frac{J}{\theta_1}\\...\\ \frac{J}{\theta_p} \end{bmatrix} xi=(xi,0,...,xi,p)Txij表示第i个样本的第j个分量θj2n1(xiTθyi)2=θj2n1(j=0pxi,jθjyi)2=n1(j=0pxi,jθjyi)xi,j=n1(f(xi)yi))xi,jθJ= θ0Jθ1J...θpJ
对于只有一个训练样本的训练组而言,每走一步,𝜃𝑗(𝑗= 0,1,…,𝑝)的更新公式就可以写成:
θ j ( t + 1 ) : = θ j ( t ) − α ∂ ∂ θ j J ( θ j ( t ) ) = θ j ( t ) − α 1 n ( f ( x i ) − y i ) x i , j \theta_j^{(t+1)} := \theta_j^{(t)} - \alpha \frac{\partial}{\partial \theta_j} J(\theta_j^{(t)}) = \theta_j^{(t)} - \alpha \frac{1}{n} (f(x_i) - y_i) x_{i,j} θj(t+1):=θj(t)αθjJ(θj(t))=θj(t)αn1(f(xi)yi)xi,j
因此,当有 n 个训练实例的时候(批处理梯度下降算法),该公式就可以写为:
θ j ( t + 1 ) : = θ j ( t ) − α 1 n ∑ i = 1 n ( f ( x i ) − y i ) x i , j \theta_j^{(t+1)}:=\theta_j^{(t)}-\alpha\frac{1}{n}\sum^n_{i=1}(f(x_i)-y_i)x_{i,j} θj(t+1):=θj(t)αn1i=1n(f(xi)yi)xi,j
这样,每次根据所有数据求出偏导,然后根据特定的步长𝛼,就可以不断更新𝜃𝑗,直到其收敛。当梯度为0或目标函数值不能继续下降的时候,就可以说已经收敛,即目标函数达到局部最小值。

具体过程可以归纳如下

1️⃣ 初始化𝜃(随机初始化)

2️⃣ 利用如下公式更新𝜃
θ j ( t + 1 ) : = θ j ( t ) − α 1 n ∑ i = 1 n ( f ( x i ) − y i ) x i , j θ ( t + 1 ) : = θ ( t ) − α 1 n ∑ i = 1 n ( f ( x i ) − y i ) x i \theta_j^{(t+1)}:=\theta_j^{(t)}-\alpha \frac{1}{n}\sum^n_{i=1}(f(x_i)-y_i)x_{i,j}\\ \theta^{(t+1)}:=\theta^{(t)}-\alpha \frac{1}{n}\sum^n_{i=1}(f(x_i)-y_i)x_{i} θj(t+1):=θj(t)αn1i=1n(f(xi)yi)xi,jθ(t+1):=θ(t)αn1i=1n(f(xi)yi)xi
其中α为步长

3️⃣ 如果新的𝜃能使𝐽(𝜃)继续减少,继续利用上述步骤更新𝜃,否则收敛,停止迭代。

http://www.dtcms.com/a/122099.html

相关文章:

  • 暑假实习面试复盘
  • Vue框架的Diff算法
  • 使用Ollama通过预训练模型获取句子向量(rest api方式)
  • GDB调试程序的基本命令和用法(Qt程序为例)
  • 三月份面试感触
  • OpenCV链接库失败,报错 无法解析的外部符号
  • SCI科学论文的重要组成部分
  • 达梦数据库迁移问题总结
  • 如何进行数据安全风险评估总结
  • Frida 调用 kill 命令挂起恢复 Android 线程
  • spring之JdbcTemplate、GoF之代理模式、面向切面编程AOP
  • 在Ubuntu 22.04上配置【C/C++编译环境】
  • 【码农日常】vscode编码clang-format格式化简易教程
  • (PTA) L2-011-L2-015
  • TDengine 窗口预聚集
  • 面试如何应用大模型
  • 算法刷题记录——LeetCode篇(1.6) [第51~60题](持续更新)
  • JAVA基础八股复习
  • 服务器DNS失效
  • DataGear结合AI工具制作多端适配的数据看板
  • Markdown标题序号处理工具——用 C 语言实现
  • 最新Web系统全面测试指南
  • lab-foundation开源程序AI/数据科学的瑞士军刀,开箱即用的数据科学/AI 平台 |AI/数据科学的瑞士军刀
  • java设计模式-代理模式
  • C语言操作符详解:从基础到进阶
  • Vue3中watch监视ref对象方法详解
  • win10开机启动文件夹所在位置
  • MQTT-Dashboard-数据集成
  • JS 箭头函数
  • 深度了解向量引论