当前位置: 首页 > news >正文

pytorch简单线性回归模型

模型五步走

1、获取数据

     1. 数据预处理

     2.归一化

     3.转换为张量

2、定义模型

3、定义损失函数和优化器

4、模型训练

5、模型评估和调优

调优方法

6、可视化(可选)

示例代码

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, r2_score# print(np.__config__.show())##1、生成数据
np.random.seed(42)
def generate_data(x, slope=2.0, intercept=1.0, noise_std=2.0):"""生成带有噪声的线性数据 y = a*x + b + ε:param x: 输入特征:param slope: 斜率 a:param intercept: 截距 b:param noise_std: 噪声标准差:return: y 数据,以及真实参数 (slope, intercept)"""y = slope * x + intercept + np.random.randn(len(x)) * noise_stdreturn y, (slope, intercept)# 使用示例
x = np.linspace(0, 10, 100)
y, true_params = generate_data(x, slope=2, intercept=1, noise_std=2)
print("真实参数:", true_params)#归一化
x_norm = (x - x.min()) / (x.max() - x.min())
y_norm = (y - y.min()) / (y.max() - y.min())#转换为pytorch张量
x_tensor = torch.tensor(x_norm, dtype=torch.float32).view(-1, 1)
y_tensor = torch.tensor(y_norm, dtype=torch.float32).view(-1, 1)#2、定义模型
class LinearRegression(nn.Module):def __init__(self,input_size,output_size):super(LinearRegression, self).__init__()self.linear = nn.Linear(input_size,output_size)def forward(self, x):out = self.linear(x)return out#实例化模型
model = LinearRegression(1,1)#3、定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)#4、训练模型
num_epochs = 10000
torch.nn.init.xavier_normal_(model.linear.weight)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)for epoch in range(num_epochs):#前向传播outputs = model(x_tensor)loss = criterion(outputs,y_tensor)#反向传播optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 1000 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')#5、输出测试结果
print('训练完成!')
print(f'权重: {model.linear.weight.item():.4f}, 偏置: {model.linear.bias.item():.4f}')#6、可视化
predicted = model(x_tensor).detach().numpy()
# 反归一化
predicted_unscaled = predicted * (y.max() - y.min()) + y.min()
y_true_unscaled = y_tensor.numpy() * (y.max() - y.min()) + y.min()# 评估指标
mae = mean_absolute_error(y_true_unscaled, predicted_unscaled)
r2 = r2_score(y_true_unscaled, predicted_unscaled)print(f'均方误差(MSE): {loss.item():.4f}')
print(f'平均绝对误差(MAE): {mae:.4f}')
print(f'R²决定系数(R²): {r2:.4f}')
r22 = r2_score(y_tensor.numpy(), predicted)
print(f"Model R² score: {r22:.4f}")#中文乱码
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.plot(x_tensor, y_tensor, 'ro', label='Original data')
plt.plot(x_tensor, predicted, label='拟合曲线')
plt.legend()
plt.show()

相关文章:

  • 黑马点评--缓存更新策略及案例实现
  • ubuntu脚本常用命令
  • Halcon 图像预处理②
  • AI时代新词-数字孪生(Digital Twin)
  • 并发的产生及对应的解决方案之服务架构说明
  • 大模型Agent
  • [开源项目] 一款功能强大的超高音质音乐播放器
  • 无网络docker镜像迁移
  • 曲线匹配,让数据点在匹配数据的一侧?
  • ADS学习笔记(五) 谐波平衡仿真
  • 电子电路原理第十七章(线性运算放大器电路的应用)
  • 开疆智能Profinet转Profibus网关连接韦普泰克工业称重仪表配置案例
  • 【Qt开发】输入类控件
  • Python 字符串相似度计算:方法、应用与实践
  • WeakAuras Lua Script [ICC BOSS 11 - Sindragosa]
  • ROS2学习(10)------ROS2参数
  • STM32F103_Bootloader程序开发03 - 启动入口与升级模式判断(boot_entry.c与boot_entry.h)
  • SOC-ESP32S3部分:13-定时器
  • 多查询检索在RAG中的应用及为什么平均嵌入向量效果好
  • 【蓝桥杯嵌入式】【模块】八、UART相关配置及代码模板
  • 优秀的网站建设/sem工作原理
  • 河南网站建设企业/提高网站排名软件
  • 通辽做网站的公司/成都优化网站哪家公司好
  • 申请一个网站空间/可以免费发广告的网站有哪些
  • 微信小店可以做分类网站/中国十大企业培训公司
  • 网站代建设费用/品牌运营推广方案