当前位置: 首页 > news >正文

使用python实现线性回归

一、概述

本代码主要演示了如何使用 Python 的 numpymatplotlib 和 sklearn 库进行简单线性回归分析。通过生成模拟数据,训练线性回归模型,对模型进行评估,并将结果可视化,帮助用户理解线性回归的基本原理和操作流程。

二、依赖库

  1. numpy:用于数值计算和数组操作,如生成随机数和处理数组数据。
  2. matplotlib.pyplot:用于数据可视化,绘制散点图和回归线。
  3. sklearn.linear_model.LinearRegression:用于创建和训练线性回归模型。
  4. sklearn.metrics.mean_squared_error 和 sklearn.metrics.r2_score:分别用于计算均方误差(MSE)和 \(R^2\) 分数,评估模型的性能。

三、代码详细解释

1. 导入必要的库

收起

python

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

  • 导入 numpy 库并将其别名为 np,方便后续使用。
  • 导入 matplotlib.pyplot 库并将其别名为 plt,用于绘图。
  • 从 sklearn.linear_model 模块中导入 LinearRegression 类,用于创建线性回归模型。
  • 从 sklearn.metrics 模块中导入 mean_squared_error 和 r2_score 函数,用于评估模型性能。

2. 设置中文字体

收起

python

plt.rcParams['font.family'] = 'SimSun'

  • 设置 matplotlib 的字体为宋体,确保在绘图时可以正常显示中文。

3. 生成模拟数据

收起

python

np.random.seed(42)  # 固定随机种子
X = np.random.rand(100, 1) * 10  # 生成 100 个 0~10 之间的特征值
y = 3 * X + 5 + np.random.randn(100, 1) * 2  # y = 3x + 5 + 噪声

  • np.random.seed(42):固定随机种子,确保每次运行代码时生成的随机数相同,方便结果的复现。
  • X = np.random.rand(100, 1) * 10:使用 np.random.rand 函数生成一个形状为 (100, 1) 的数组,数组中的元素是 0 到 1 之间的随机数,然后将其乘以 10,得到 100 个 0 到 10 之间的特征值。
  • y = 3 * X + 5 + np.random.randn(100, 1) * 2:根据真实方程 \(y = 3x + 5\) 生成目标值,并添加高斯噪声(使用 np.random.randn 函数生成),模拟真实世界中的数据。

4. 创建和训练线性回归模型

收起

python

# 创建线性回归模型
model = LinearRegression()

# 训练模型
model.fit(X, y)

  • model = LinearRegression():创建一个 LinearRegression 类的实例,即一个线性回归模型。
  • model.fit(X, y):使用生成的特征值 X 和目标值 y 对模型进行训练,让模型学习 X 和 y 之间的线性关系。

5. 模型预测

收起

python

# 预测
y_pred = model.predict(X)

  • y_pred = model.predict(X):使用训练好的模型对特征值 X 进行预测,得到预测的目标值 y_pred

6. 获取模型参数

收起

python

# 获取模型参数
slope = model.coef_[0][0]  # 斜率
intercept = model.intercept_[0]  # 截距

  • slope = model.coef_[0][0]:从模型的系数(斜率)数组中获取斜率值。
  • intercept = model.intercept_[0]:从模型的截距数组中获取截距值。

7. 打印模型参数和评估指标

收起

python

# 打印模型参数和评估指标
print(f"真实方程: y = 3x + 5")
print(f"学习到的方程: y = {slope:.2f}x + {intercept:.2f}")
print(f"均方误差 (MSE): {mean_squared_error(y, y_pred):.2f}")
print(f"R² 分数: {r2_score(y, y_pred):.2f}")

  • 打印真实方程和模型学习到的方程,方便对比。
  • 使用 mean_squared_error 函数计算均方误差(MSE),衡量模型预测值与真实值之间的平均误差。
  • 使用 r2_score 函数计算 \(R^2\) 分数,评估模型对数据的拟合程度,\(R^2\) 分数越接近 1 表示模型拟合效果越好。

8. 可视化结果

收起

python

# 可视化
plt.figure(figsize=(10, 6))
plt.scatter(X, y, color='blue', label='原始数据', alpha=0.6)
plt.plot(X, y_pred, color='red', linewidth=2, label='回归线')
plt.plot(X, 3*X+5, color='green', linestyle='--', label='真实关系')
plt.xlabel('X')
plt.ylabel('y')
plt.title('线性回归示例')
plt.legend()
plt.grid(True)
plt.show()

  • plt.figure(figsize=(10, 6)):创建一个大小为 (10, 6) 的图形窗口。
  • plt.scatter(X, y, color='blue', label='原始数据', alpha=0.6):绘制原始数据的散点图,颜色为蓝色,设置透明度为 0.6。
  • plt.plot(X, y_pred, color='red', linewidth=2, label='回归线'):绘制模型的回归线,颜色为红色,线宽为 2。
  • plt.plot(X, 3*X+5, color='green', linestyle='--', label='真实关系'):绘制真实的线性关系,颜色为绿色,线型为虚线。
  • plt.xlabel('X') 和 plt.ylabel('y'):设置 x 轴和 y 轴的标签。
  • plt.title('线性回归示例'):设置图形的标题。
  • plt.legend():显示图例,方便区分不同的图形元素。
  • plt.grid(True):显示网格线,增强图形的可读性。
  • plt.show():显示绘制好的图形。

四、总结

通过本代码示例,我们可以看到如何使用 sklearn 库进行简单线性回归分析,包括数据生成、模型训练、预测、评估和可视化。用户可以根据需要修改代码中的参数,如随机种子、数据规模、噪声水平等,进一步探索线性回归的特性。

完整代码

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score


plt.rcParams['font.family'] = 'SimSun'

# 生成模拟数据
np.random.seed(42)  # 固定随机种子
X = np.random.rand(100, 1) * 10  # 生成 100 个 0~10 之间的特征值
y = 3 * X + 5 + np.random.randn(100, 1) * 2  # y = 3x + 5 + 噪声

# 创建线性回归模型
model = LinearRegression()

# 训练模型
model.fit(X, y)

# 预测
y_pred = model.predict(X)

# 获取模型参数
slope = model.coef_[0][0]  # 斜率
intercept = model.intercept_[0]  # 截距

# 打印模型参数和评估指标
print(f"真实方程: y = 3x + 5")
print(f"学习到的方程: y = {slope:.2f}x + {intercept:.2f}")
print(f"均方误差 (MSE): {mean_squared_error(y, y_pred):.2f}")
print(f"R² 分数: {r2_score(y, y_pred):.2f}")

# 可视化
plt.figure(figsize=(10, 6))
plt.scatter(X, y, color='blue', label='原始数据', alpha=0.6)
plt.plot(X, y_pred, color='red', linewidth=2, label='回归线')
plt.plot(X, 3*X+5, color='green', linestyle='--', label='真实关系')
plt.xlabel('X')
plt.ylabel('y')
plt.title('线性回归示例')
plt.legend()
plt.grid(True)
plt.show()

相关文章:

  • 修改DOSBox的窗口大小
  • 启动你的RocketMQ之旅(四)-Producer启动和发送流程(下)
  • 国产开源AI平台Cherry Studio详解:联网搜索升级与ChatBox对比指南
  • spring.profiles.active和spring.profiles.include的使用及区别说明
  • 基于html的俄罗斯方块小游戏(附程序)
  • MCAL-I/O驱动
  • 考研408数据结构第三章(栈、队列和数组)核心易错点深度解析
  • 01_NLP基础之文本处理的基本方法
  • 附录-Python — 包下载缓慢,配置下载镜像
  • 河南理工XCPC萌新选拔赛
  • SEO长尾词优化进阶法则
  • 【3天快速入门WPF】11-附加属性
  • 绪论(3)
  • AtCoder Beginner Contest 001(A - 積雪深差、B - 視程の通報、C - 風力観測、D - 感雨時刻の整理)题解
  • 如何通过Python网络爬虫技术应对复杂的反爬机制?
  • 物联网同RFID功能形态 使用场景的替代品
  • Mac OS Homebrew更换国内镜像源(中科大;阿里;清华)
  • 数据结构秘籍(四) 堆 (详细包含用途、分类、存储、操作等)
  • 【C++】ImGui:极简化的立即模式GUI开发
  • 【数据挖掘】Matplotlib
  • php网站开发专业/如何快速搭建网站
  • 网站设计 佛山/百度seo优化是做什么的
  • 陕西住房建设厅官方网站/十大免费无代码开发软件
  • 网站源码模块/优化大师免费下载
  • 网站代码开发/杭州网站优化多少钱
  • 5星做号宿水软件的网站/天津百度关键词seo