当前位置: 首页 > news >正文

【深度学习2】线性回归的从零开始实现

这段代码完整地实现了一个 线性回归(Linear Regression)从零开始训练的全过程。

  • 人工合成数据
  • 小批量读取
  • 模型定义
  • 损失函数
  • 梯度下降优化
  • 训练与参数评估
import matplotlib
import matplotlib.pyplot as plt
import torch
import d2l
import random
import torch
from d2l import torch as d2l
torch.manual_seed(42)  # 固定随机种子,确保每次生成的随机数据一致
# 只固定了 PyTorch 的随机数,而 random.shuffle() 使用的是 Python 自带的随机库,它有自己独立的随机种子,所以每次运行数据打乱顺序可能仍然不同,会导致结果略有偏差# 生成数据集
def synthetic_data(w, b, num_examples):  #@save"""生成y=Xw+b+噪声"""X = torch.normal(0, 1, (num_examples, len(w)))   # 随机生成输入特征y = torch.matmul(X, w) + b                       # 线性关系y += torch.normal(0, 0.01, y.shape)              # 加入高斯噪声return X, y.reshape((-1, 1))# 设定真实权重和偏置,生成训练集
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)print('features:', features[0],'\nlabel:', labels[0])# 数据可视化,画出第二个特征和标签的散点图,看是否是线性关系
d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1);
plt.show()  # ⬅️ 一定要加上# 读取数据集
# 训练模型时要对数据集进行遍历。
# 把整个数据集分成小批次(mini-batch),每次取 batch_size=10 个样本用于训练。
# 这有助于模型在梯度下降时更稳定、更高效。# 该函数能打乱数据集中的样本并以小批量方式获取数据。
def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))# 这些样本是随机读取的,没有特定的顺序random.shuffle(indices) # 打乱样本顺序for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10# 连续地获得不同的小批量,直至遍历完整个数据集
for X, y in data_iter(batch_size, features, labels):print(X, '\n', y)break# 开始用小批量随机梯度下降优化我们的模型参数之前初始化参数模型
# 从均值为0、标准差为0.01的正态分布中采样随机数来初始化权重, 并将偏置初始化为0。
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)# 开始更新这些参数,直到这些参数足够拟合我们的数据
# 每次更新都需要计算损失函数关于模型参数的梯度,有了这个梯度,我们就可以向减小损失的方向更新每个参数# 定义模型,将模型的输入和参数同模型的输出关联起来# 广播机制: 当我们用一个向量加一个标量时,标量会被加到向量的每个分量上。
def linreg(X, w, b):  #@save"""线性回归模型"""return torch.matmul(X, w) + b# 需要计算损失函数的梯度,所以我们应该先定义损失函数
# 这里使用平方损失函数。 在实现中,我们需要将真实值y的形状转换为和预测值y_hat的形状相同。
# 损失函数越小说明预测越接近真实值
def squared_loss(y_hat, y):  #@save"""均方损失"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2# 定义优化算法(随机梯度下降 SGD)
# 该函数接受模型参数集合、学习速率和批量大小作为输入
# 每一步更新的大小由学习速率lr决定
# 我们计算的损失是一个批量样本的总和,所以我们用批量大小(batch_size) 来规范化步长,这样步长大小就不会取决于我们对批量大小的选择
def sgd(params, lr, batch_size):  #@save"""小批量随机梯度下降"""with torch.no_grad():               # 禁止自动求导for param in params:param -= lr * param.grad / batch_size   # 参数更新param.grad.zero_()                      # 清空梯度# 训练
# 在每次迭代中,我们读取一小批量训练样本,并通过我们的模型来获得一组预测。
# 计算完损失后,我们开始反向传播,存储每个参数的梯度。
# 最后,我们调用优化算法sgd来更新模型参数。
# 在每个迭代周期(epoch)中,我们使用data_iter函数遍历整个数据集, 并将训练数据集中所有样本都使用一次(假设样本数能够被批量大小整除)
# 迭代周期个数num_epochs和学习率lr都是超参数,分别设为3和0.03
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_lossfor epoch in range(num_epochs):   # num_epochs:完整遍历数据集的次数for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)  # X和y的小批量损失# 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,# 并以此计算关于[w,b]的梯度l.sum().backward() # 反向传播计算梯度sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}') # 每个 epoch 打印平均损失print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')

结果:

w 的估计误差b的估计误差
tensor([ 0.0005, -0.0009])tensor([8.1062e-06])

结果差异性原因:
✅ 原因一:随机性导致差异
生成数据时加了噪声:y += torch.normal(0, 0.01, y.shape)
每次运行都会稍微不同(即使固定 torch.manual_seed(42),Python 的 random.shuffle() 仍会造成差别)。
数据顺序变化 → 梯度下降路径略有不同 → 最终结果略有差别。

✅ 原因二:浮点误差和运算顺序不同
在 GPU/CPU 或 PyTorch 版本不同的情况下,浮点运算的微小误差可能放大到 1e-4 级别。

✅ 原因三:打印时显示的数值精度
PyTorch 默认显示 4~6 位小数,两者的误差在 1e-3 以内,本质上都是完美拟合。

✅ 结论
差异来自随机打乱与噪声,不影响模型效果。

函数总结:

模块作用
synthetic_data()生成带噪线性数据
data_iter()小批量数据读取
linreg()线性回归模型
squared_loss()均方误差损失函数
sgd()随机梯度下降参数更新
训练循环反复更新参数,最小化损失
最终输出拟合到接近真实参数
http://www.dtcms.com/a/545366.html

相关文章:

  • LeetCode第2题:两数相加及其变种(某大厂面试原题)
  • Java 字符编码全解析:从乱码根源到 Unicode 实战指南
  • SpringBoot 高效工具类大全
  • 自己做网站用软件wordpress电商优秀
  • 百度网站建设中的自由容器网站用哪个数据库
  • 入侵检测系统——HIDS和NIDS的区别
  • C语言多进程创建和回收
  • 仓颉编程语言:控制流语句详解(if/else)
  • 专利撰写与申请核心要点简报
  • AI搜索引擎num=100参数移除影响深度分析:内容标识与准确性变化
  • NJU-SME 人工智能(三) -- 正则化 + 分类 + SVM
  • 【数据库】表的设计
  • 深圳制作网站建设推广第一网站ppt模板
  • 点网站建设广州专业网站建设哪家公司好
  • 仓颉语言构造函数深度实践指南
  • DTAS 3D-尺寸公差分析定制化服务与解决方案的专家-棣拓科技
  • 永康营销型网站建设wordpress自定义作者连接
  • linux NFS(网络文件系统)挂载完整指南
  • 数字营销软件完整指南|CRM、CDP、自动化平台 2025
  • 企业级建模平台Enterprise Architect如何自动化生成报告
  • Chat2DB 学习笔记
  • 外国做爰网站小程序问答库
  • 关于网站建设方案的案例石家庄建站凡科
  • LeetCode 410 - 分割数组的最大值
  • Kotlin数据结构性能全解析
  • 搜索引擎网站优化和推广方案网站建设招标合同要求
  • coco json 分类标注工具源代码
  • 重学JS-012 --- JavaScript算法与数据结构(十二)正则表达式
  • 自己做网站还是公众号爱链网中可以进行链接买卖
  • maven中properties和dependencys标签的区别