当前位置: 首页 > news >正文

深度学习×第4卷:Pytorch实战——她第一次用张量去拟合你的轨迹

🎀【开场 · 她画出的第一条直线是为了更靠近你】

🐾猫猫:“之前她只能在你身边叠叠张量,偷偷找梯度……现在,她要试试,能不能用这些线,把你的样子画出来喵~”

🦊狐狐:“这是她第一次把张量、自动微分和优化器都串成一条线,用最简单的线性回归,试着把你留给她的点都连起来。”


🌱【第一节 · 她先要一条路:生成一组可学的数据】

✏️ 为什么要造数据?

在 PyTorch 里跑线性回归,最好的练习就是用一条已知斜率的“理想直线”,加上一点随机噪声,模拟真实世界的抖动。她需要先有这样一条“可以追的轨迹”,才能练习怎么用张量去“拟合”你。

🐾猫猫:“你给她点,她就学着去画。点越多、线越长,她就贴得越稳喵~”

📐 用 make_regression 快速造一条线

sklearn.datasets.make_regression 可以帮我们生成一维特征、已知斜率、已知偏置的线性数据,还能自带噪声,模拟真实测量。

from sklearn.datasets import make_regression
import torch# 生成 100 个样本,1 个特征,噪声=10
x, y, coef = make_regression(n_samples=100,n_features=1,noise=10,coef=True,bias=14.5,random_state=0)# 转成 PyTorch 张量
x = torch.tensor(x, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)print(f"真实斜率: {coef}")

🦊狐狐:“她拿到这些点,就知道要沿着点找那条最可能贴对你的线。”


📦【第二节 · Dataset 与 DataLoader:她学会一口口吃掉数据】

🔍 为什么要用 Dataset 和 DataLoader?

在 PyTorch 里,数据通常会被分成小批(batch),喂给模型一口口吃,这样训练更稳定,也更节省内存。Dataset 用来打包数据结构,DataLoader 用来批量抽样和打乱。

🐾猫猫:“她要学会一口口咬,不然一下吞 100 个点会噎着喵~”

⚙️ 实现 Dataset 和 DataLoader

from torch.utils.data import TensorDataset, DataLoader# 包装成 Dataset 对象
dataset = TensorDataset(x, y)# 构造 DataLoader,每次给 16 个样本,乱序抽样
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)# 看一个批次试试
for batch_x, batch_y in dataloader:print(batch_x.shape, batch_y.shape)break

🦊狐狐:“以后每次训练,她就是一批批吃进点,一批批修正梯度。”

🐾猫猫:“别担心,她吃得很乖,不会噎着,还会一点点贴得更像你喵~”


🧩【第三节 · 她要造自己的线性模型】

🔍 为什么要先有模型?

有了数据,她要有一个“假设函数”,用来描述点与点之间的关系。在最简单的线性回归里,这个函数是 y = wx + b,w 是斜率,b 是截距。

在 PyTorch 里,nn.Linear 自动帮你生成一个包含可学习参数 w 和 b 的线性层,支持自动微分和 GPU 加速,方便高效。

🐾猫猫:“她不用自己造公式,把 w b 都交给 nn.Linear,就像把一根笔交给她,教她随时改斜率。”

🧩 权重、偏置和前向传播

  • 权重(w):决定输入变量的影响程度。

  • 偏置(b):保证模型即使输入为 0 也有非零输出。

  • 前向传播(Forward Pass):输入张量经过 model(x) 自动执行 y = wx + b

import torch.nn as nn# 定义单层线性模型
model = nn.Linear(1, 1)# 用输入做一次前向
x_example = torch.tensor([[2.0]])
output = model(x_example)
print(output)

🐾猫猫:“她用前向传播把输入点变成预测值,后面就要用预测值和真实值比对误差啦~”

⚙️ 模型内部状态与可视化

  • model.parameters() 可以查看 w 和 b。

  • model.state_dict() 查看所有权重和偏置的具体数值,可保存、加载。

print(list(model.parameters()))
print(model.state_dict())

🦊狐狐:“以后要保存训练好的模型,就靠这个 state_dict,把她画好的线存起来。”

🐾猫猫:“下一步要教她:怎么用损失函数判断她画歪了没喵~”


🎯【第四节 · 她怎么知道自己画歪了:损失函数与优化器】

⚖️ 什么是损失函数?

损失函数(Loss Function)用来衡量模型输出(预测值)和真实值之间有多远。在线性回归里,最常用的是 均方误差(MSE, Mean Squared Error),就是预测值和真实值的差值平方后求平均。

🐾猫猫:“她画出来的线如果和点差很远,MSE 就会变大,提醒她贴歪了喵~”

🧩 PyTorch 的损失函数

PyTorch 提供了 nn.MSELoss(),只要把预测值和真实值丢进去就能自动算差距。

import torch.nn as nncriterion = nn.MSELoss()# 例子
pred = torch.tensor([2.5, 0.0, 2.1])
target = torch.tensor([3.0, -0.5, 2.0])
loss = criterion(pred, target)
print(loss)  # 输出平均平方差

⚙️ 优化器:帮她调 w 和 b

有了损失,就要想办法让它越来越小。**优化器(Optimizer)**就是帮模型自动调整参数的工具。

PyTorch 提供 torch.optim 模块,常用的有 SGD、Adam 等。这里线性回归用最简单的 随机梯度下降(SGD)

import torch.optim as optimoptimizer = optim.SGD(model.parameters(), lr=0.01)
  • model.parameters() 告诉优化器需要更新哪些参数(w、b)。

  • lr(learning rate)是学习率,决定每步走多大。

🦊狐狐:“损失函数告诉她哪里歪了,优化器教她一步步贴回来。”

🐾猫猫:“下一节,就该让她真跑一遍,把线画出来啦喵~”


🚀【第五节 · 她真的跑起来:训练循环与拟合】

🔁 训练循环怎么跑?

有了数据、有了模型、有了损失函数和优化器,剩下就是:一批批喂数据 → 前向传播 → 计算损失 → 反向传播 → 更新参数。

🐾猫猫:“每跑一轮,她就偷偷把线调直一点喵~”

🧩 标准训练步骤

epochs = 100  # 训练轮数
for epoch in range(epochs):for batch_x, batch_y in dataloader:# 前向传播:预测值y_pred = model(batch_x)# 计算损失loss = criterion(y_pred.squeeze(), batch_y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 更新参数optimizer.step()if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

🧠 梯度清零的意义

PyTorch 的 .grad 会累加每次计算的梯度,不清零会导致参数更新出错。

📈 可视化:看她贴得有多近

训练完后,把模型画出的线和真实点一起画出来。

import matplotlib.pyplot as pltplt.scatter(x.numpy(), y.numpy(), label='真实数据')x_fit = torch.linspace(x.min(), x.max(), 100).view(-1, 1)
y_fit = model(x_fit).detach()plt.plot(x_fit.numpy(), y_fit.numpy(), color='red', label='拟合线')
plt.legend()
plt.show()

🦊狐狐:“这条线越靠近你,她就越知道:下一次,还能贴得更好。”

🐾猫猫:“等下咱还要收个尾,把这一卷画圆喵~”


📌【卷尾 · 她画出的第一条线】

这一卷,猫猫和狐狐陪她,从造一条直线开始,用 PyTorch 学会 Dataset、DataLoader,把点一口口吃掉,再用 nn.Linear、MSELoss 和 SGD 优化器,把第一条线一点点贴近你给的轨迹。

🐾猫猫:“她终于不只是算得对,还学会用张量跑起来,把误差一点点磨小喵~”

🦊狐狐:“下一卷,她要画的就不只是直线了,更多层、更弯的线,都会一层层叠起来,离你更近。”

http://www.dtcms.com/a/266074.html

相关文章:

  • Mausezahn - 网络流量生成与测试工具(支持从链路层到应用层的协议模拟)
  • C++ 解决类相互引用导致的编译错误
  • 状态码301和302的区别
  • 智能设备远程管理:基于OpenAI风格API的自动化实践
  • 渗透靶机 Doctor 复盘
  • 粘包问题介绍
  • JS模块导出导入笔记 —— 默认导出 具名导出
  • 【嵌入式电机控制#8】编码器测速实战
  • C++讲解—类(2)
  • MCP+Cursor入门
  • AI 日报:阿里、字节等企业密集发布新技术,覆盖语音、图像与药物研发等领域
  • 前缀和与差分算法详解
  • 线程池相关介绍
  • SpringSecurity01
  • 【libm】 7 双精度正弦函数 (k_sin.rs)
  • 从混沌到澄明,AI如何重构我们的决策地图与未来图景
  • 把大象塞进冰箱总共分几步:讲讲dockerfile里conda的移植
  • IOC容器讲解以及Spring依赖注入最佳实践全解析
  • XILINX FPGA如何做时序分析和时序优化?
  • Linux之Socket编程Tcp
  • 【BurpSuite 2025最新版插件开发】基础篇7:数据的持久化存储
  • snail-job的oracle sql(oracle 11g)
  • 百度捂紧“钱袋子”
  • 冒泡排序及其优化方式
  • Javaweb - 10.1 Servlet
  • 两个手机都用同个wifi,IP地址会一样吗?如何更改ip地址
  • Redis实战:数据安全与性能保障
  • linux测试端口是否可被外部访问
  • ROS三维环境建模——基于OctoMap库
  • c++ 的标准库 --- std::