当前位置: 首页 > news >正文

深度学习【迭代梯度下降法求解线性回归】


梯度下降法

梯度下降法是一种常用迭代方法,其目的是让输入向量找到一个合适的迭代方向,使得输出值能达到局部最小值。在拟合线性回归方程时,我们把损失函数视为以参数向量为输入的函数,找到其梯度下降的方向并进行迭代,就能找到最优的参数值。

1.计算对于给定的线性模型 (y = wx + b) 的均方误差(MSE)。它接受截距 (b)、斜率 (w) 和点集 (points),然后遍历所有点,计算每个点的预测值,与真实值之差的平方和,最后返回平均误差。

2.更新w和b

3.多次迭代后,得到最优的w和b,也就是y=wx+b这个模型对于给定数据集的最优

这里给定数据集:100个(x,y)

import torch
import numpy as np

#计算给定点集的线性回归的误差  y = wx + b
def compute_error_for_line_given_points(b,w,points):
    total_error = 0
    for i in range(len(points)):
        x = points[i,0]
        y = points[i,1]
        total_error += (y - (w*x + b))**2
    return total_error/float(len(points))

#梯度下降法求解线性回归  w = w - learning_rate * w_gradient, b = b - learning_rate * b_gradient
def step_gradient(b_current,w_current,points,learning_rate):
    b_gradient = 0
    w_gradient = 0
    n = float(len(points))
    for i in range(len(points)):
        x = points[i,0]
        y = points[i,1]
        b_gradient += -(2/n) * (y - ((w_current*x) + b_current))
        w_gradient += -(2/n) * x * (y - ((w_current*x) + b_current))
    new_b = b_current - (learning_rate * b_gradient)
    new_w = w_current - (learning_rate * w_gradient)
    return [new_b,new_w]

#迭代梯度下降法求解线性回归
def gradient_descent_runner(points,starting_b,starting_w,learning_rate,num_iterations):
    b = starting_b
    w = starting_w
    for i in range(num_iterations):
        b,w = step_gradient(b,w,points,learning_rate)
    return [b,w]

def run():
    points = np.genfromtxt('data.csv', delimiter=',')
    learning_rate = 0.0001
    initial_b = 0
    initial_w = 0
    num_iterations = 1000
    print("Starting gradient descent at b = {0}, w = {1}, error = {2}".format(initial_b,initial_w,compute_error_for_line_given_points(initial_b,initial_w,points)))
    print("Running...")
    [b,w] = gradient_descent_runner(points,initial_b,initial_w,learning_rate,num_iterations)
    print("After {0} iterations b = {1}, w = {2}, error = {3}".format(num_iterations,b,w,compute_error_for_line_given_points(b,w,points)))

if __name__ == '__main__':
    run()

执行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.dtcms.com/a/74341.html

相关文章:

  • 在 macOS Sequoia 15.2 中启用「三指拖动」并实现快速复制的完整指南 ✨
  • 深度学习-简介
  • 学生选课管理系统数据库设计报告
  • Git下载安装(保姆教程)
  • torcharrow gflags版本问题
  • 动作捕捉手套如何让虚拟现实人机交互 “触手可及”?
  • 【入门初级篇】窗体的基本操作与功能介绍
  • 分布式唯一ID
  • Linux FILE文件操作2- fopen、fclose、fgetc、fputc、fgets、fputs验证
  • Java 大视界 -- Java 大数据机器学习模型的对抗攻击与防御技术研究(137)
  • 【嵌入式】复刻SQFMI开源的Watchy墨水屏电子表——(2)软件部分
  • Git 的使用上传下载和更新
  • 【数学 线性代数】差分约束
  • Python----计算机视觉处理(Opencv:图像颜色替换)
  • 三维重建(十七)——obj文件解读+ply文件解读
  • 搞了搞Python,写了个图片对比程序及AI硅基流动对话
  • BFF与API Gateway的区别解析
  • Socket 、WebSocket、Socket.IO详细对比
  • Dify 搭建
  • 智能汽车图像及视频处理方案,支持视频智能包装创作能力
  • koupleless 合并多个微服务应用到一个应用实例(包含springcloud gateway)
  • w259交通管理在线服务系统设计与实现
  • Nginx限流与鉴权(Nginx Traffic Limiting and Authentication)
  • JS逆向:泛微OA的前端密码加密逆向分析,并使用Python构建泛微OA登录
  • [023-01-47].第47节:组件应用 - GetWay与 Sentinel 集成实现服务限流
  • 3.17学习总结 java数组
  • Compose 实践与探索十四 —— 自定义布局
  • 第四章 搜索基础
  • python项目一键加密,极度简洁
  • 【嵌入式硬件】三款DCDC调试笔记