当前位置: 首页 > news >正文

pytorch简单线性回归模型

模型五步走

1、获取数据

     1. 数据预处理

     2.归一化

     3.转换为张量

2、定义模型

3、定义损失函数和优化器

4、模型训练

5、模型评估和调优

调优方法

6、可视化(可选)

示例代码

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, r2_score# print(np.__config__.show())##1、生成数据
np.random.seed(42)
def generate_data(x, slope=2.0, intercept=1.0, noise_std=2.0):"""生成带有噪声的线性数据 y = a*x + b + ε:param x: 输入特征:param slope: 斜率 a:param intercept: 截距 b:param noise_std: 噪声标准差:return: y 数据,以及真实参数 (slope, intercept)"""y = slope * x + intercept + np.random.randn(len(x)) * noise_stdreturn y, (slope, intercept)# 使用示例
x = np.linspace(0, 10, 100)
y, true_params = generate_data(x, slope=2, intercept=1, noise_std=2)
print("真实参数:", true_params)#归一化
x_norm = (x - x.min()) / (x.max() - x.min())
y_norm = (y - y.min()) / (y.max() - y.min())#转换为pytorch张量
x_tensor = torch.tensor(x_norm, dtype=torch.float32).view(-1, 1)
y_tensor = torch.tensor(y_norm, dtype=torch.float32).view(-1, 1)#2、定义模型
class LinearRegression(nn.Module):def __init__(self,input_size,output_size):super(LinearRegression, self).__init__()self.linear = nn.Linear(input_size,output_size)def forward(self, x):out = self.linear(x)return out#实例化模型
model = LinearRegression(1,1)#3、定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)#4、训练模型
num_epochs = 10000
torch.nn.init.xavier_normal_(model.linear.weight)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)for epoch in range(num_epochs):#前向传播outputs = model(x_tensor)loss = criterion(outputs,y_tensor)#反向传播optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 1000 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')#5、输出测试结果
print('训练完成!')
print(f'权重: {model.linear.weight.item():.4f}, 偏置: {model.linear.bias.item():.4f}')#6、可视化
predicted = model(x_tensor).detach().numpy()
# 反归一化
predicted_unscaled = predicted * (y.max() - y.min()) + y.min()
y_true_unscaled = y_tensor.numpy() * (y.max() - y.min()) + y.min()# 评估指标
mae = mean_absolute_error(y_true_unscaled, predicted_unscaled)
r2 = r2_score(y_true_unscaled, predicted_unscaled)print(f'均方误差(MSE): {loss.item():.4f}')
print(f'平均绝对误差(MAE): {mae:.4f}')
print(f'R²决定系数(R²): {r2:.4f}')
r22 = r2_score(y_tensor.numpy(), predicted)
print(f"Model R² score: {r22:.4f}")#中文乱码
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.plot(x_tensor, y_tensor, 'ro', label='Original data')
plt.plot(x_tensor, predicted, label='拟合曲线')
plt.legend()
plt.show()


文章转载自:

http://dm5Wvkqd.mnwmj.cn
http://PU73fPCz.mnwmj.cn
http://B3w29pf7.mnwmj.cn
http://9wEM9BFO.mnwmj.cn
http://esfMCS3j.mnwmj.cn
http://s3on7gQ9.mnwmj.cn
http://WoZqZijr.mnwmj.cn
http://olkMWaj7.mnwmj.cn
http://HvtXtugJ.mnwmj.cn
http://VzOhmqq6.mnwmj.cn
http://Dv1bPyt3.mnwmj.cn
http://sjJKu39b.mnwmj.cn
http://Q1ZXaVig.mnwmj.cn
http://mELk4gIK.mnwmj.cn
http://bmIxrbtW.mnwmj.cn
http://hUpbaZVu.mnwmj.cn
http://78uMHdQK.mnwmj.cn
http://vg4Lu6CC.mnwmj.cn
http://nYcdDbrD.mnwmj.cn
http://gmPeLx4B.mnwmj.cn
http://dpnDQP4U.mnwmj.cn
http://HC5Oc1rz.mnwmj.cn
http://McIKaNuR.mnwmj.cn
http://fsmW2FKl.mnwmj.cn
http://M2DXSCw3.mnwmj.cn
http://QC0xVRWH.mnwmj.cn
http://GxHxmxLc.mnwmj.cn
http://nviprBf5.mnwmj.cn
http://FWbYPXMd.mnwmj.cn
http://Im81X3cF.mnwmj.cn
http://www.dtcms.com/a/215042.html

相关文章:

  • 黑马点评--缓存更新策略及案例实现
  • ubuntu脚本常用命令
  • Halcon 图像预处理②
  • AI时代新词-数字孪生(Digital Twin)
  • 并发的产生及对应的解决方案之服务架构说明
  • 大模型Agent
  • [开源项目] 一款功能强大的超高音质音乐播放器
  • 无网络docker镜像迁移
  • 曲线匹配,让数据点在匹配数据的一侧?
  • ADS学习笔记(五) 谐波平衡仿真
  • 电子电路原理第十七章(线性运算放大器电路的应用)
  • 开疆智能Profinet转Profibus网关连接韦普泰克工业称重仪表配置案例
  • 【Qt开发】输入类控件
  • Python 字符串相似度计算:方法、应用与实践
  • WeakAuras Lua Script [ICC BOSS 11 - Sindragosa]
  • ROS2学习(10)------ROS2参数
  • STM32F103_Bootloader程序开发03 - 启动入口与升级模式判断(boot_entry.c与boot_entry.h)
  • SOC-ESP32S3部分:13-定时器
  • 多查询检索在RAG中的应用及为什么平均嵌入向量效果好
  • 【蓝桥杯嵌入式】【模块】八、UART相关配置及代码模板
  • [De1CTF 2019]SSRF Me
  • 今日行情明日机会——20250526
  • Redis批量删除Key的三种方式
  • 【杂谈】------使用 __int128 处理超大整数计算
  • MyBatis深度解析:XML/注解配置与动态SQL编写实战
  • TinyVue v3.23.0 正式发布:增加 NumberAnimation 数字动画组件、支持全局配置组件的 props
  • FreeRTOS 在物联网传感器节点的应用:低功耗实时数据采集与传输方案
  • 资源回收:为地球减负,共创绿色未来
  • 超临界机组协调控制系统建模项目开发笔记
  • ubuntu 22.04 配置静态IP、网关、DNS