当前位置: 首页 > news >正文

PyTorch_构建线性回归

使用 PyTorch 的 API 来手动构建一个线性回归的假设函数,数据加载器,损失函数,优化方法,绘制训练过程中的损失变化。


数据构建

import torch
from sklearn.datasets import make_regression 
import matplotlib.pyplot as plt 
import random # 构建数据集
def create_dataset():x, y, coef = make_regression(n_samples = 100, n_features = 1, noise = 10, coef= True, bias = 14.5, random_state = 0) # 将构建数据转换为张量类型x = torch.tensor(x)y = torch.tensor(y)return x, y # 构建数据加载器
def data_load(x, y, batch_size):# 计算样本数量data_len = len(y)# 构建数据索引data_index = list(range(data_len))# 数据集打乱random.shuffle(data_index)# 计算总的batch数量batch_number = data_len // batch_size for idx in range(batch_number):start = idx * batch_size end = start + batch_size batch_train_x = x[start: end]batch_train_y = y[start: end]yield batch_train_x, batch_train_ydef test01():x, y = create_dataset()plt.scatter(x, y)plt.show()for x, y in data_load(x, y, batch_size=10):print(y)if __name__ == "__main__":test01() 

构建假设函数,损失函数,优化方法

所谓的假设函数,就是线性回归的方程。

损失函数:使用平方损失

优化方法:梯度下降

# 构建假设函数
w = torch.tensor(0.1, requires_grad=True, dtype=torch.float64)
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float64)def linear_regression(x):return w * x + b # 损失函数
def square_loss(y_pred, y_true):return (y_pred - y_true) ** 2 # 优化方法
def sgd(learning_rate = 0.01):# 16 是批次样本的平均梯度值。 batch sizew.data = w.data - learning_rate * w.grad.data / 16b.data = b.data - learning_rate * b.grad.data / 16

训练函数

# 训练函数
def train():# 加载数据集x, y, coef = create_dataset()# 定义训练参数epochs = 100learning_rate = 0.01# 存储损失epoch_loss = []total_loss = 0.0 train_sample = 0 for _ in range(epochs):for train_x, train_y in data_load(x, y, 16):# 训练数据送入模型进行预测y_pred = linear_regression(train_x)# 计算预测值和真实值的平方损失loss = square_loss(y_pred, train_y.reshape(-1, 1)).sum()total_loss += loss.item()train_sample += len(train_y)# 梯度清零if w.grad is not None:w.grad.data.zero_() if b.grad is not None:b.grad.data.zero_() # 自动微分loss.backward()# 更新参数sgd(learning_rate)print('loss: %.10f' % (total_loss / train_sample))epoch_loss.append(total_loss / train_sample)# 绘制拟合直线print(coef, w.data.item())plt.scatter(x, y)x = torch.linspace(x.min(), x.max(), 1000)y1 = torch.tensor([v * w + 14.5 for v in x])y2 = torch.tensor([v * coef + 14.5 for v in x])plt.plot(x, y1, label = '训练')plt.plot(x, y2, label = '真实')plt.grid()plt.legend()plt.show()# 打印损失变化曲线plt.plot(range(epochs), epoch_loss)plt.title('损失变化曲线')plt.grid()plt.show()

相关文章:

  • 《TCP/IP详解 卷1:协议》之第十章:动态选路协议
  • 使用银行卡识别API,使信息上传更便捷
  • 2025系统架构师---论软件的设计模式论文
  • 【Python】Python好玩的第三方库之二维码生成,操作xlsx文件,以及音频控制器
  • LIO-SAM笔记(三)适配Livox 激光雷达
  • 【OSPF协议深度解析】从原理到企业级网络部署
  • vue展示graphviz和dot流程图
  • DeepSeek学术论文写作全流程指令
  • PrivKV: Key-Value Data Collection with Local Differential Privacy论文阅读
  • Python爬虫实战:获取58同城网最新房源数据并分析,为用户租房做参考
  • CMake基础介绍
  • Redis总结(六)redis持久化
  • AutoGPT
  • 笔试专题(十五)
  • 如何扫描系统漏洞?漏洞扫描的原理是什么?
  • 【HarmonyOS 5】鸿蒙应用数据安全详解
  • 在macOS上安装windows系统
  • 《数据结构初阶》【顺序栈 + 链式队列 + 循环队列】
  • android-ndk开发(6): 查看反汇编
  • 1.openharmony环境搭建
  • 央行:5月8日起,下调个人住房公积金贷款利率0.25个百分点
  • 纪念|“补白大王”郑逸梅,从藏扇看其眼光品味
  • 中演协:五一假期全国营业性演出票房收入同比增长3.6%
  • 新华每日电讯:上海“绿色大民生”撑起“春日大经济”
  • 强沙尘暴压城近万名游客被困,敦煌如何用3小时跑赢12级狂风?
  • 许昌市场监管部门对胖东来玉石开展日常检查:平均毛利率不超20%