当前位置: 首页 > news >正文

Python----机器学习(线性回归:前向传播和损失函数)

一、前向传播

        前向传播是指在一个机器学习算法中,从输入到输出的信息传递过程,具体 来说,就是在数据输入后,经过一系列的运算后得到结果的过程。

1.1、前向计算

1.2、单点误差

        由上图可知,当w等于0的时候“拟合”这些散点的效果并不好,而在拖动的过程中发现w的值为1的时候与散点的拟合程度相对不错

        从图中可以看出,我们无法通过一条直线将所有的点都囊括进来,于是只能找一条线,让这些点到直线 的距离尽可能的小。点到直线的距离并不用复杂的距离公式来表示,而是通过在x值相同的 情况下,线上x对应的y与点的y值之间的差来表示。

1.3、损失函数:均方差

导入模块

import numpy as np
import matplotlib.pyplot as plt

数据聚集输入

data = np.array(
    [
        [0.8,1.0],
        [1.7,0.9],
        [2.7,2.4],
        [3.2,2.9],
        [3.7,2.8],
        [4.2,3.8],
        [4.2,2.7]
    ])
#将特征和标签(需要拟合的目标)分离
x_data=data[:,0]
y_data=data[:,1]

前向计算

#y=w*x+b
w=0.9
b=0
y_hat=w*x_data+b

单点误差

e=y_data-y_hat
print(e)

 均方误差(损失函数)

e_=(np.mean((y_data-y_hat)**2))
print(e_)

 图像绘制

fig=plt.figure(figsize=(10,5))
ax1=fig.add_subplot(1,2,1)
ax2=fig.add_subplot(1,2,2)
# 装饰坐标轴
ax1.set_xlim(0,5)
ax1.set_xlim(0,6)
ax1.set_xlabel("x axis label")
ax1.set_ylabel("y axis label")
# 绘制数据集散点
ax1.scatter(x_data,y_data,color='b')
# 计算并绘制拟合线
y_lower=w*0+b
y_upper=w*5+b
ax1.plot([0,5],[y_lower,y_upper],color='r',linewidth=3)
# 左侧图点到线的竖直线(距离)
for i,j,k in zip(x_data,y_data,y_hat):
    ax1.plot([i,i],[j,k],color='g',linestyle='-')

# 绘制右侧w和e的曲线
w_values=np.linspace(0,3,100)
e_values=[np.mean(y_data-(w_value*x_data+b))**2 for w_value in w_values]
# 在曲线上绘制w的点
ax2.plot(w_values,e_values,color='g',linestyle='-')
ax2.plot(w,e_,marker='o',color='r')
plt.show()

完整代码

import numpy as np  # 导入 NumPy 库用于数值计算  
import matplotlib.pyplot as plt  # 导入 Matplotlib 库用于数据可视化  

# 1. 数据聚集输入  
data = np.array(  # 定义一个二维 NumPy 数组,包含 x 和 y 的数据点  
    [  
        [0.8, 1.0],  
        [1.7, 0.9],  
        [2.7, 2.4],  
        [3.2, 2.9],  
        [3.7, 2.8],  
        [4.2, 3.8],  
        [4.2, 2.7]  
    ])  
# 将特征和标签(需要拟合的目标)分离  
x_data = data[:, 0]  # 提取 x 数据  
y_data = data[:, 1]  # 提取 y 数据  

# 2. 前向计算  
# y = w * x + b  
w = 0.9  # 初始化权重(斜率)  
b = 0  # 初始化偏置(截距)  
y_hat = w * x_data + b  # 计算预测的 y 值  

# 3. 单点误差  
e = y_data - y_hat  # 计算每个点的误差(真实值与预测值之差)  
print(e)  # 打印误差数组  

# 4. 均方误差(损失函数)  
e_ = np.mean((y_data - y_hat) ** 2)  # 计算均方误差(MSE)  
print(e_)  # 打印均方误差  

# 5. 图像绘制  
fig = plt.figure(figsize=(10, 5))  # 创建一个图形对象,设置图形尺寸  
ax1 = fig.add_subplot(1, 2, 1)  # 左侧子图  
ax2 = fig.add_subplot(1, 2, 2)  # 右侧子图  

# 装饰坐标轴  
ax1.set_xlim(0, 5)  # 设置 x 轴范围  
ax1.set_ylim(0, 6)  # 设置 y 轴范围(注意原语句是设置 x 轴,已更正为 y 轴)  
ax1.set_xlabel("x axis label")  # x 轴标签  
ax1.set_ylabel("y axis label")  # y 轴标签  

# 绘制数据集散点  
ax1.scatter(x_data, y_data, color='b')  # 绘制数据点,蓝色表示  

# 计算并绘制拟合线  
y_lower = w * 0 + b  # 计算拟合线在 x=0 时的 y 值  
y_upper = w * 5 + b  # 计算拟合线在 x=5 时的 y 值  
ax1.plot([0, 5], [y_lower, y_upper], color='r', linewidth=3)  # 绘制拟合线,红色表示  

# 左侧图点到线的竖直线(距离)  
for i, j, k in zip(x_data, y_data, y_hat):  
    ax1.plot([i, i], [j, k], color='g', linestyle='-')  # 绘制每个数据点到拟合线的竖直距离,绿色表示  

# 绘制右侧w和e的曲线  
w_values = np.linspace(0, 3, 100)  # 创建一个从 0 到 3 的权重值数组(100个点)  
e_values = [np.mean(y_data - (w_value * x_data + b)) ** 2 for w_value in w_values]  # 计算每个 w 值对应的均方误差  

# 在曲线上绘制w的点  
ax2.plot(w_values, e_values, color='g', linestyle='-')  # 绘制 w 与均方误差的关系曲线  
ax2.plot(w, e_, marker='o', color='r')  # 在曲线中标记当前 w 值和对应的均方误差,红点表示  

plt.show()  # 展示绘图结果  
http://www.dtcms.com/a/108366.html

相关文章:

  • 【C++基础知识】 C 预处理器中的 #line 指令详解
  • RabbitMQ应用2
  • Linux系统之SFTP-搭建SFTP服务器
  • ui-tars和omni-parser使用
  • JavaScript 模块化详解( CommonJS、AMD、CMD、ES6模块化)
  • 网络安全-等级保护(等保) 1-0 等级保护制度公安部前期发文总结
  • 蓝桥杯 web 表格数据转化(组件挂载、模板字符串)
  • 【硬件视界9】网络硬件入门:从网卡到路由器
  • C# 扩展方法
  • 跨网连接vscode
  • 银联三级等保定级报告
  • CMake学习--Window下VSCode 中 CMake C++ 代码调试操作方法
  • 闭环SOTA!北航DiffAD:基于扩散模型实现端到端自动驾驶「多任务闭环统一」
  • 面基spring如何处理循环依赖问题
  • conda 清除 tarballs 减少磁盘占用 、 conda rename 重命名环境、conda create -n qwen --clone 当前环境
  • 机器学习、深度学习和神经网络
  • vscode调试python(transformers库的llama为例)
  • C#实现HiveQL建表语句中特殊数据类型的包裹
  • 用docker部署goweb项目
  • RainbowDash 的 Robot
  • C++学习笔记(三十一)——map
  • Git的基础使用方法
  • 微信小程序唤起app
  • 【Docker】使用Docker快速部署n8n和unclecode/crawl4ai
  • PEFT实战(一)——LoRA
  • 大模型学习一:deepseek api 调用实战以及参数介绍
  • 【动手学深度学习】#7 现代卷积神经网络
  • C++多态:从青铜九鼎到虚函数表的千年演化密码
  • Pytorch|RNN-心脏病预测
  • 文件分享系统--使用AI Trae开发前后端