当前位置: 首页 > news >正文

【深度学习3】线性回归的简洁实现

# 使用 PyTorch 实现线性回归
import torch
from torch import nn
from d2l import torch as d2l  # 使用 d2l 的 PyTorch 版本# -------------------------------
# 1. 生成数据集
# -------------------------------
# true_w: 真实权重向量
# true_b: 真实偏置
# synthetic_data: 生成线性模型 y = Xw + b + 噪声 的数据集
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)# -------------------------------
# 2. 构造数据迭代器
# -------------------------------
# batch_size: 每个小批量的样本数
# load_array: 将 features 和 labels 封装为 PyTorch DataLoader
#           shuffle=True 表示在训练时每个 epoch 打乱样本顺序
batch_size = 10
def load_array(data_arrays, batch_size, is_train=True):"""构造一个 PyTorch 数据迭代器"""dataset = torch.utils.data.TensorDataset(*data_arrays)return torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train)data_iter = load_array((features, labels), batch_size)# -------------------------------
# 3. 定义模型
# -------------------------------
# nn.Sequential: 顺序容器,将多个层按顺序组合
# nn.Linear(2, 1): 全连接层
#     输入特征数=2
#     输出特征数=1(标量预测)
net = nn.Sequential(nn.Linear(2, 1))# 初始化模型参数
# 权重从均值为0,标准差为0.01的正态分布初始化
# 偏置初始化为0
nn.init.normal_(net[0].weight, mean=0, std=0.01)
nn.init.zeros_(net[0].bias)# -------------------------------
# 4. 定义损失函数
# -------------------------------
# 均方误差(MSELoss),对每个样本损失取平均
loss = nn.MSELoss()# -------------------------------
# 5. 定义优化算法
# -------------------------------
# SGD: 小批量随机梯度下降
# net.parameters(): 要优化的模型参数集合(权重和偏置)
# lr: 学习率
trainer = torch.optim.SGD(net.parameters(), lr=0.03)# -------------------------------
# 6. 训练
# -------------------------------
# num_epochs: 迭代周期数
# 对于每一个小批量:
#   1. 通过 net(X) 得到预测 y_hat(前向传播)
#   2. 计算损失 l
#   3. 梯度清零
#   4. 反向传播计算梯度
#   5. 通过优化器更新参数
num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X), y)     # 前向传播计算小批量损失trainer.zero_grad()     # 清空梯度l.backward()            # 反向传播计算梯度trainer.step()          # 更新模型参数# 每个 epoch 输出整个训练集的平均损失train_l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {float(train_l):f}')# -------------------------------
# 7. 输出参数误差
# -------------------------------
# 比较训练得到的参数与真实参数的差距
w = net[0].weight.data
b = net[0].bias.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
print('b的估计误差:', true_b - b)

结果:

w 的估计误差b的估计误差
tensor([-0.0005, 0.0005])tensor([-3.5286e-05])
http://www.dtcms.com/a/550743.html

相关文章:

  • 招商网站建设哪家好济南中桥信息做的小语种网站怎么样
  • 可视化建网站网站关键词和描述
  • 无人机巡护青海湖,AI如何守护西部生态与能源安全?
  • wordpress短代码可视化常州seo网络推广
  • 网站免费做app专门做萝莉视频网站
  • 呼和浩特网站建设SEO优化做网站的目的是什么
  • python进阶教程3:内存池、内存分配优化
  • 网站流程图容桂品牌网站建设优惠
  • 程序与工业:从附庸到共生,在AI浪潮下的高维重构
  • 免费的制作手机网站平台wordpress dux主题设置首页
  • 口碑好的网站定制公司wordpress mdtf
  • 网站建设 开票全国网络公司大全
  • 站群系统破解版昆明百度推广优化排名
  • 快速建站系统网站游戏网站怎么自己做
  • 公司网站要备案么上海ktv最新通知
  • Rust 中的路由匹配与参数提取:类型安全的 HTTP 路径解析艺术
  • 电商网站开发 文献综述wordpress插件汉化下载
  • 最常用的网站推广方式代做网站收费标准
  • Slicer中VolumeNode与切片视图实现的机制
  • 仓颉编程(22)扩展
  • 电子商城网站开发项目描述wordpress图片轮播
  • wordpress建站多个域名网络运营是干什么的
  • 高端网站建设制作网站过期了怎么办
  • 专业公司网站建设公司做网站找谁公司做网站找谁
  • Java基础——常用API
  • 【001】Java开发环境
  • linux网站建设技术指南推广普通话的标语
  • 大专人力资源专业毕业生能做 HR 助理吗?入门条件有什么?【一文说清楚】
  • 淘宝网站怎么做的邯郸手机网站开发价格
  • 【王树森深度强化学习】基本概念 Deep Reinforcement Learning (1/5)