当前位置: 首页 > news >正文

机器学习算法-一元线性回归(最小二乘拟合 and 梯度下降)

先来看看理论知识吧

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Python实现

import numpy as np
import sympy as spX = [1, 2, 3, 4, 5]
Y = [1, 3, 6, 7, 6]class LinerRegress:def __init__(self, learn_rate=0.001, gradient_times=100):# 训练特征数据self.train_x = []# 训练标签数据self.train_y = []# 线性回归斜率self.w = 0# 线性回归截距self.b = 0# 学习率self.learn_rate = learn_rate# 梯度下降次数self.gradient_times = gradient_timesdef fit(self, X, Y):self.train_x, self.train_y = np.array(X), np.array(Y)w,b = self.min_twicemultiply(self.train_x,self.train_y)self.w, self.b = self.gradient_descend(self.train_x, self.train_y, w, b, self.gradient_times)# 最小二乘拟合确定 w 和 bdef min_twicemultiply(self,X,Y):x0, x1 = np.ones(len(X)), np.array(X)x0x0, x1x1 = np.dot(x0, x0), np.dot(x1, x1)x1x0 = x0x1 = np.dot(x0, x1)yx0, yx1 = np.dot(self.train_y, x0), np.dot(self.train_y, x1)w = (x1x0 * yx0 - x0x0 * yx1) / (x0x1 * x1x0 - x0x0 * x1x1)b = (x1x1 * yx0 - x0x1 * yx1) / (x1x1 * x0x0 - x1x0 * x0x1)return w,b# 梯度下降调优参数 w 和 bdef gradient_descend(self, x, y, w, b, times):_w, _b = sp.symbols('_w _b')size = len(y)# 损失函数loss_f1 同时对w和b做偏导数求极小值# 这里损失函数loss_f2只对截距w做了"梯度下降"算法调整loss_f1, loss_f2 = 0.0, 0.0for xv, yv in zip(x, y):loss_f1 += (yv - _w * xv - _b) ** 2loss_f2 += (yv - _w * xv - b) ** 2loss_f1, loss_f2 = sp.simplify(loss_f1 / (2 * size)), sp.simplify(loss_f2 / (2 * size))res = sp.solve([sp.diff(loss_f1, _w), sp.diff(loss_f1, _b)], [_w, _b])# loss_f1 的w和b的极值点存在if len(res) != 0:return float(res.get(_w)), float(res.get(_b))else:next_w = wloss_f2_diff = sp.diff(loss_f2, _w)for _ in range(times):next_w -= loss_f2_diff.subs({_w: next_w}) * self.learn_ratereturn next_w, bdef predict(self, data):return self.w * data + self.bmodel = LinerRegress()
model.fit(X, Y)
print(model.predict(6))

相关文章:

  • java三种常见设计模式,工厂、策略、责任链
  • OWASP Juice-Shop靶场(⭐⭐)
  • aws(学习笔记第四十二课) serverless-backend
  • 2025年5月系分论文题(回忆版)
  • 为什么size_t重要,size_t的大小
  • 理论物理:为什么在极低温(接近绝对零度)时,经典理论失效?
  • 并发编程艺术--AQS底层源码解析(二)
  • 多线程的基础知识以及应用
  • 计算机视觉---YOLOv2
  • 2021年认证杯SPSSPRO杯数学建模B题(第二阶段)依巴谷星表中的毕星团求解全过程文档及程序
  • 计算机网络学习(六)——UDP
  • Go语言Map的底层原理
  • mysql都有哪些锁?
  • Java并发编程:全面解析锁策略、CAS与synchronized优化机制
  • 基于SpringBoot的校园电竞赛事系统
  • uni-app学习笔记十二-vue3中组件传值(属性传值)
  • Redis之金字塔模型分层架构
  • [医学影像 AI] 使用 PyTorch 和 MedicalZooPytorch 实现 3D 医学影像分割
  • Linux Kernel调试:强大的printk(二)
  • 两个mysql的maven依赖要用哪个?
  • 上海建设工程信息网查询/谷歌seo网站优化
  • 虚拟空间能建多个网站/关键词调词平台
  • 长春电商网站建设公司排名/宁波seo网络推广产品服务
  • 甜橙直播/济南seo网站排名关键词优化
  • 罗湖网站建设的公司哪家好/网站友情链接
  • 网站接入服务单位名称/南宁优化网站网络服务