当前位置: 首页 > news >正文

深度学习【迭代梯度下降法求解线性回归】


梯度下降法

梯度下降法是一种常用迭代方法,其目的是让输入向量找到一个合适的迭代方向,使得输出值能达到局部最小值。在拟合线性回归方程时,我们把损失函数视为以参数向量为输入的函数,找到其梯度下降的方向并进行迭代,就能找到最优的参数值。

1.计算对于给定的线性模型 (y = wx + b) 的均方误差(MSE)。它接受截距 (b)、斜率 (w) 和点集 (points),然后遍历所有点,计算每个点的预测值,与真实值之差的平方和,最后返回平均误差。

2.更新w和b

3.多次迭代后,得到最优的w和b,也就是y=wx+b这个模型对于给定数据集的最优

这里给定数据集:100个(x,y)

import torch
import numpy as np

#计算给定点集的线性回归的误差  y = wx + b
def compute_error_for_line_given_points(b,w,points):
    total_error = 0
    for i in range(len(points)):
        x = points[i,0]
        y = points[i,1]
        total_error += (y - (w*x + b))**2
    return total_error/float(len(points))

#梯度下降法求解线性回归  w = w - learning_rate * w_gradient, b = b - learning_rate * b_gradient
def step_gradient(b_current,w_current,points,learning_rate):
    b_gradient = 0
    w_gradient = 0
    n = float(len(points))
    for i in range(len(points)):
        x = points[i,0]
        y = points[i,1]
        b_gradient += -(2/n) * (y - ((w_current*x) + b_current))
        w_gradient += -(2/n) * x * (y - ((w_current*x) + b_current))
    new_b = b_current - (learning_rate * b_gradient)
    new_w = w_current - (learning_rate * w_gradient)
    return [new_b,new_w]

#迭代梯度下降法求解线性回归
def gradient_descent_runner(points,starting_b,starting_w,learning_rate,num_iterations):
    b = starting_b
    w = starting_w
    for i in range(num_iterations):
        b,w = step_gradient(b,w,points,learning_rate)
    return [b,w]

def run():
    points = np.genfromtxt('data.csv', delimiter=',')
    learning_rate = 0.0001
    initial_b = 0
    initial_w = 0
    num_iterations = 1000
    print("Starting gradient descent at b = {0}, w = {1}, error = {2}".format(initial_b,initial_w,compute_error_for_line_given_points(initial_b,initial_w,points)))
    print("Running...")
    [b,w] = gradient_descent_runner(points,initial_b,initial_w,learning_rate,num_iterations)
    print("After {0} iterations b = {1}, w = {2}, error = {3}".format(num_iterations,b,w,compute_error_for_line_given_points(b,w,points)))

if __name__ == '__main__':
    run()

执行结果:

相关文章:

  • 在 macOS Sequoia 15.2 中启用「三指拖动」并实现快速复制的完整指南 ✨
  • 深度学习-简介
  • 学生选课管理系统数据库设计报告
  • Git下载安装(保姆教程)
  • torcharrow gflags版本问题
  • 动作捕捉手套如何让虚拟现实人机交互 “触手可及”?
  • 【入门初级篇】窗体的基本操作与功能介绍
  • 分布式唯一ID
  • Linux FILE文件操作2- fopen、fclose、fgetc、fputc、fgets、fputs验证
  • Java 大视界 -- Java 大数据机器学习模型的对抗攻击与防御技术研究(137)
  • 【嵌入式】复刻SQFMI开源的Watchy墨水屏电子表——(2)软件部分
  • Git 的使用上传下载和更新
  • 【数学 线性代数】差分约束
  • Python----计算机视觉处理(Opencv:图像颜色替换)
  • 三维重建(十七)——obj文件解读+ply文件解读
  • 搞了搞Python,写了个图片对比程序及AI硅基流动对话
  • BFF与API Gateway的区别解析
  • Socket 、WebSocket、Socket.IO详细对比
  • Dify 搭建
  • 智能汽车图像及视频处理方案,支持视频智能包装创作能力
  • 阿森纳被打得毫无脾气,回天无力的阿尔特塔只剩嘴硬
  • 视频|漫画家寂地:古老丝路上的文化与交流留下的独特印记
  • 深圳下调公积金利率,209万纯公积金贷款总利息减少9.94万
  • 明查|这是“C919迫降在农田”?实为飞机模型将用于科普体验
  • 外交部回应西班牙未来外交战略:愿与之一道继续深化开放合作
  • 债券市场“科技板”来了:哪些机构能尝鲜,重点支持哪些领域